Skip to content

Commit 19d091b

Browse files
authored
Merge pull request #61 from PenGuln/1.3.1
update map at k loss
2 parents 85068de + f33ada9 commit 19d091b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

libauc/losses/auc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,13 +787,14 @@ def forward(self, y_pred, y_true, task_id=[], auto=True, **kwargs):
787787

788788
class meanAveragePrecisionLoss(torch.nn.Module):
789789
r"""
790-
Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP. This is an extension of :obj:`~libauc.losses.APLoss`.
790+
Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP and mAP@k. This is an extension of :obj:`~libauc.losses.APLoss`.
791791
792792
Args:
793793
data_len (int): total number of samples in the training dataset.
794794
num_labels (int): number of unique labels(tasks) in the dataset.
795795
margin (float, optional): margin for the squared-hinge surrogate loss (default: ``1.0``).
796796
gamma (float, optional): parameter for the moving average estimator (default: ``0.9``).
797+
top_k (int, optional): If given, only top k items will be considered for optimizing mAP@k.
797798
surr_loss (str, optional): type of surrogate loss to use. Choices are 'squared_hinge', 'squared',
798799
'logistic', 'barrier_hinge' (default: ``'squared_hinge'``).
799800
@@ -819,6 +820,7 @@ def __init__(self,
819820
num_labels,
820821
margin=1.0,
821822
gamma=0.9,
823+
top_k=-1,
822824
surr_loss='squared_hinge',
823825
device=None):
824826
super(meanAveragePrecisionLoss, self).__init__()
@@ -833,6 +835,7 @@ def __init__(self,
833835
self.margin = margin
834836
self.gamma = gamma
835837
self.surrogate_loss = get_surrogate_loss(surr_loss)
838+
self.top_k = top_k
836839

837840
def forward(self, y_pred, y_true, index, task_id=[], **kwargs):
838841
y_pred = check_tensor_shape(y_pred, (-1, self.num_labels))
@@ -856,6 +859,9 @@ def forward(self, y_pred, y_true, index, task_id=[], **kwargs):
856859
self.u_all[idx][index_i] = (1 - self.gamma) * self.u_all[idx][index_i] + self.gamma * (sur_loss.mean(1, keepdim=True)).detach()
857860
self.u_pos[idx][index_i] = (1 - self.gamma) * self.u_pos[idx][index_i] + self.gamma * (pos_sur_loss.mean(1, keepdim=True)).detach()
858861
p_i = (self.u_pos[idx][index_i] - (self.u_all[idx][index_i]) * pos_mask) / (self.u_all[idx][index_i] ** 2) # size of p_i: len(f_ps)* len(y_pred)
862+
if self.top_k > -1:
863+
selector = torch.sigmoid(self.top_k - sur_loss.sum(dim=0, keepdim=True).clone())
864+
p_i *= selector
859865
p_i.detach_()
860866
loss = torch.mean(p_i * sur_loss)
861867
total_loss += loss

0 commit comments

Comments
 (0)