Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions center_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import warnings

import torch
import torch.nn as nn
from torch import nn
from torch.nn import functional as F


class CenterLoss(nn.Module):
"""Center loss.
Expand All @@ -11,34 +15,22 @@ class CenterLoss(nn.Module):
num_classes (int): number of classes.
feat_dim (int): feature dimension.
"""
def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
def __init__(self, num_classes: int = 10, feat_dim: int = 2, use_gpu: bool = None, clamp: int = 1e-12):
super(CenterLoss, self).__init__()
if use_gpu is not None:
warnings.warning(f"Ignoring explicitly set {use_gpu=}. Move the model via .to(device)")
self.num_classes = num_classes
self.feat_dim = feat_dim
self.use_gpu = use_gpu
self.clamp = clamp

if self.use_gpu:
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
else:
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

def forward(self, x, labels):
"""
Args:
x: feature matrix with shape (batch_size, feat_dim).
labels: ground truth labels with shape (batch_size).
"""
batch_size = x.size(0)
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
distmat.addmm_(1, -2, x, self.centers.t())

classes = torch.arange(self.num_classes).long()
if self.use_gpu: classes = classes.cuda()
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))

dist = distmat * mask.float()
loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

return loss
centers = torch.index_select(self.centers, 0, labels.view(-1)) # [Classes, Features] (gather) [Batch] -> [Batch, Features]
return F.mse_loss(x, centers) * self.feat_dim # mean across all axes except features