@@ -787,13 +787,14 @@ def forward(self, y_pred, y_true, task_id=[], auto=True, **kwargs):
 
 class meanAveragePrecisionLoss(torch.nn.Module):
     r"""
-    Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP. This is an extension of :obj:`~libauc.losses.APLoss`.
+    Mean Average Precision loss based on squared-hinge surrogate loss to optimize mAP and mAP@k. This is an extension of :obj:`~libauc.losses.APLoss`.
 
     Args:
         data_len (int): total number of samples in the training dataset.
         num_labels (int): number of unique labels (tasks) in the dataset.
         margin (float, optional): margin for the squared-hinge surrogate loss (default: ``1.0``).
         gamma (float, optional): parameter for the moving average estimator (default: ``0.9``).
+        top_k (int, optional): if given, only the top-k items are considered for optimizing mAP@k (default: ``-1``, i.e., all items are used and plain mAP is optimized).
         surr_loss (str, optional): type of surrogate loss to use. Choices are 'squared_hinge', 'squared',
             'logistic', 'barrier_hinge' (default: ``'squared_hinge'``).
 
@@ -819,6 +820,7 @@ def __init__(self,
                  num_labels,
                  margin=1.0,
                  gamma=0.9,
+                 top_k=-1,
                  surr_loss='squared_hinge',
                  device=None):
         super(meanAveragePrecisionLoss, self).__init__()
@@ -833,6 +835,7 @@ def __init__(self,
         self.margin = margin
         self.gamma = gamma
         self.surrogate_loss = get_surrogate_loss(surr_loss)
+        self.top_k = top_k
 
     def forward(self, y_pred, y_true, index, task_id=[], **kwargs):
         y_pred = check_tensor_shape(y_pred, (-1, self.num_labels))
@@ -856,6 +859,9 @@ def forward(self, y_pred, y_true, index, task_id=[], **kwargs):
             self.u_all[idx][index_i] = (1 - self.gamma) * self.u_all[idx][index_i] + self.gamma * (sur_loss.mean(1, keepdim=True)).detach()
             self.u_pos[idx][index_i] = (1 - self.gamma) * self.u_pos[idx][index_i] + self.gamma * (pos_sur_loss.mean(1, keepdim=True)).detach()
             p_i = (self.u_pos[idx][index_i] - (self.u_all[idx][index_i]) * pos_mask) / (self.u_all[idx][index_i] ** 2)  # size of p_i: len(f_ps) * len(y_pred)
+            if self.top_k > -1:
+                selector = torch.sigmoid(self.top_k - sur_loss.sum(dim=0, keepdim=True).clone())
+                p_i *= selector
             p_i.detach_()
             loss = torch.mean(p_i * sur_loss)
             total_loss += loss