@@ -787,13 +787,14 @@ def forward(self, y_pred, y_true, task_id=[], auto=True, **kwargs):

class meanAveragePrecisionLoss(torch.nn.Module):
    r"""
-        Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP. This is an extension of :obj:`~libauc.losses.APLoss`.
+        Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP and mAP@k. This is an extension of :obj:`~libauc.losses.APLoss`.

        Args:
            data_len (int): total number of samples in the training dataset.
            num_labels (int): number of unique labels (tasks) in the dataset.
            margin (float, optional): margin for the squared-hinge surrogate loss (default: ``1.0``).
            gamma (float, optional): parameter for the moving average estimator (default: ``0.9``).
+            top_k (int, optional): if given, only the top-k ranked items are considered when optimizing mAP@k (default: ``-1``, i.e. disabled).
            surr_loss (str, optional): type of surrogate loss to use. Choices are 'squared_hinge', 'squared',
                'logistic', 'barrier_hinge' (default: ``'squared_hinge'``).

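A minimal usage sketch of the extended loss (not part of the diff): it assumes the class is exposed as `libauc.losses.meanAveragePrecisionLoss`, and all shapes and values are purely illustrative. Passing `top_k=-1` (the default) recovers plain mAP optimization.

```python
import torch
from libauc.losses import meanAveragePrecisionLoss  # assumed import path

data_len, num_labels, batch_size = 1000, 5, 32

# top_k=10 optimizes mAP@10; top_k=-1 (default) optimizes plain mAP
loss_fn = meanAveragePrecisionLoss(data_len=data_len,
                                   num_labels=num_labels,
                                   margin=1.0,
                                   gamma=0.9,
                                   top_k=10,
                                   surr_loss='squared_hinge')

y_pred = torch.randn(batch_size, num_labels, requires_grad=True)  # model scores
y_true = (torch.rand(batch_size, num_labels) > 0.7).float()       # multi-label targets
index  = torch.arange(batch_size)                                 # dataset indices for the moving-average estimators

loss = loss_fn(y_pred, y_true, index)
loss.backward()
```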
@@ -819,6 +820,7 @@ def __init__(self,
                 num_labels,
                 margin=1.0,
                 gamma=0.9,
+                top_k=-1,
                 surr_loss='squared_hinge',
                 device=None):
        super(meanAveragePrecisionLoss, self).__init__()
@@ -833,6 +835,7 @@ def __init__(self,
        self.margin = margin
        self.gamma = gamma
        self.surrogate_loss = get_surrogate_loss(surr_loss)
+        self.top_k = top_k

    def forward(self, y_pred, y_true, index, task_id=[], **kwargs):
        y_pred = check_tensor_shape(y_pred, (-1, self.num_labels))
@@ -856,6 +859,9 @@ def forward(self, y_pred, y_true, index, task_id=[], **kwargs):
            self.u_all[idx][index_i] = (1 - self.gamma) * self.u_all[idx][index_i] + self.gamma * (sur_loss.mean(1, keepdim=True)).detach()
            self.u_pos[idx][index_i] = (1 - self.gamma) * self.u_pos[idx][index_i] + self.gamma * (pos_sur_loss.mean(1, keepdim=True)).detach()
            p_i = (self.u_pos[idx][index_i] - (self.u_all[idx][index_i]) * pos_mask) / (self.u_all[idx][index_i] ** 2)  # size of p_i: len(f_ps) * len(y_pred)
+            if self.top_k > -1:
+                selector = torch.sigmoid(self.top_k - sur_loss.sum(dim=0, keepdim=True).clone())
+                p_i *= selector
            p_i.detach_()
            loss = torch.mean(p_i * sur_loss)
            total_loss += loss
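The added `selector` acts as a soft top-k gate: `torch.sigmoid(self.top_k - sur_loss.sum(dim=0, keepdim=True))` is close to 1 for columns whose accumulated surrogate loss stays below `top_k` and decays toward 0 once it exceeds it, so `p_i *= selector` smoothly down-weights those columns rather than masking them hard. Below is a small, self-contained sketch of just that arithmetic; the shapes are assumptions, and how `sur_loss` is actually built lies outside this hunk.

```python
import torch

top_k = 3
num_pos, batch = 4, 8                      # assumed shapes: rows = positives, cols = batch samples
sur_loss = torch.rand(num_pos, batch)      # stand-in for the pairwise surrogate losses

# per-column accumulated surrogate loss, shape (1, batch)
col_sum = sur_loss.sum(dim=0, keepdim=True)

# soft gate: ~1 where col_sum < top_k, decaying toward 0 where col_sum >> top_k
selector = torch.sigmoid(top_k - col_sum)

p_i = torch.rand(num_pos, batch)           # stand-in for the weight matrix in the diff
p_i = p_i * selector                       # the (1, batch) gate broadcasts across rows
```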