From 8d3d7aa9af0d21de5d8c9fdcad6984f2cbfcef08 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Feb 2023 12:25:13 +0100 Subject: [PATCH 1/3] style: port to torch 1.13 --- center_loss.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/center_loss.py b/center_loss.py index 9e87012..4dddfe7 100644 --- a/center_loss.py +++ b/center_loss.py @@ -1,5 +1,9 @@ +import warnings + import torch -import torch.nn as nn +from torch import nn +from torch.nn import functional as F + class CenterLoss(nn.Module): """Center loss. @@ -11,16 +15,16 @@ class CenterLoss(nn.Module): num_classes (int): number of classes. feat_dim (int): feature dimension. """ - def __init__(self, num_classes=10, feat_dim=2, use_gpu=True): + def __init__(self, num_classes: int = 10, feat_dim: int = 2, use_gpu: bool = None, clamp: float = 1e-12): super(CenterLoss, self).__init__() + if use_gpu is not None: + warnings.warn(f"Ignoring explicitly set {use_gpu=}. Move the model via .to(device)") self.num_classes = num_classes self.feat_dim = feat_dim self.use_gpu = use_gpu + self.clamp = clamp - if self.use_gpu: - self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) - else: - self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) + self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) def forward(self, x, labels): """ @@ -28,17 +32,9 @@ def forward(self, x, labels): x: feature matrix with shape (batch_size, feat_dim). labels: ground truth labels with shape (batch_size). 
""" - batch_size = x.size(0) - distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ - torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() - distmat.addmm_(1, -2, x, self.centers.t()) - - classes = torch.arange(self.num_classes).long() - if self.use_gpu: classes = classes.cuda() - labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) - mask = labels.eq(classes.expand(batch_size, self.num_classes)) - - dist = distmat * mask.float() - loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size - - return loss + centers = centers.t() + distmat = x.square().sum(dim=1, keepdim=True) + centers.square().sum(dim=0, keepdim=True) + # B F @ B C -> Gather C -> B F @ F + distmat = distmat - 2 * x @ centers + dist = torch.gather(distmat, 1, labels.view(-1, 1)) + return dist.clamp(min=self.clamp, max=1 / self.clamp).mean() From b676bb9903ed0e7d8159bf0679316f8b5adb10e7 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Feb 2023 12:34:27 +0100 Subject: [PATCH 2/3] style: index_select centers + binomial expansion --- center_loss.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/center_loss.py b/center_loss.py index 4dddfe7..511bd9c 100644 --- a/center_loss.py +++ b/center_loss.py @@ -32,9 +32,6 @@ def forward(self, x, labels): x: feature matrix with shape (batch_size, feat_dim). labels: ground truth labels with shape (batch_size). 
""" - centers = centers.t() - distmat = x.square().sum(dim=1, keepdim=True) + centers.square().sum(dim=0, keepdim=True) - # B F @ B C -> Gather C -> B F @ F - distmat = distmat - 2 * x @ centers - dist = torch.gather(distmat, 1, labels.view(-1, 1)) + centers = torch.index_select(self.centers, 0, labels.view(-1)) # [Classes, Features] (gather) [Batch] -> [Batch, Features] + dist = (x - centers).square().sum(dim=1, keepdim=False) return dist.clamp(min=self.clamp, max=1 / self.clamp).mean() From eb24046a21ca441810b8dd9b3b830756e2f9d08f Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Feb 2023 12:38:37 +0100 Subject: [PATCH 3/3] style: just use mse loss --- center_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/center_loss.py b/center_loss.py index 511bd9c..aaa1aff 100644 --- a/center_loss.py +++ b/center_loss.py @@ -33,5 +33,4 @@ def forward(self, x, labels): labels: ground truth labels with shape (batch_size). """ centers = torch.index_select(self.centers, 0, labels.view(-1)) # [Classes, Features] (gather) [Batch] -> [Batch, Features] - dist = (x - centers).square().sum(dim=1, keepdim=False) - return dist.clamp(min=self.clamp, max=1 / self.clamp).mean() + return F.mse_loss(x, centers) * self.feat_dim # mean across all axes except features