diff --git a/center_loss.py b/center_loss.py index 9e87012..aaa1aff 100644 --- a/center_loss.py +++ b/center_loss.py @@ -1,5 +1,9 @@ +import warnings + import torch -import torch.nn as nn +from torch import nn +from torch.nn import functional as F + class CenterLoss(nn.Module): """Center loss. @@ -11,16 +15,16 @@ class CenterLoss(nn.Module): num_classes (int): number of classes. feat_dim (int): feature dimension. """ - def __init__(self, num_classes=10, feat_dim=2, use_gpu=True): + def __init__(self, num_classes: int = 10, feat_dim: int = 2, use_gpu: bool = None, clamp: int = 1e-12): super(CenterLoss, self).__init__() + if use_gpu is not None: + warnings.warning(f"Ignoring explicitly set {use_gpu=}. Move the model via .to(device)") self.num_classes = num_classes self.feat_dim = feat_dim self.use_gpu = use_gpu + self.clamp = clamp - if self.use_gpu: - self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) - else: - self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) + self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) def forward(self, x, labels): """ @@ -28,17 +32,5 @@ def forward(self, x, labels): x: feature matrix with shape (batch_size, feat_dim). labels: ground truth labels with shape (batch_size). """ - batch_size = x.size(0) - distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ - torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() - distmat.addmm_(1, -2, x, self.centers.t()) - - classes = torch.arange(self.num_classes).long() - if self.use_gpu: classes = classes.cuda() - labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) - mask = labels.eq(classes.expand(batch_size, self.num_classes)) - - dist = distmat * mask.float() - loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size - - return loss + centers = torch.index_select(self.centers, 0, labels.view(-1)) # [Classes, Features] (gather) [Batch] -> [Batch, Features] + return F.mse_loss(x, centers) * self.feat_dim # mean across all axes except features