@@ -11,22 +11,28 @@ class L2Wrap(torch.autograd.Function):
     This version is memory-optimized by not storing the full logits tensor.
     """
     @staticmethod
-    def forward(ctx, loss, logits):
-        ctx.save_for_backward(logits)
+    def forward(ctx, loss, logits, l2_penalty_factor=1e-4):
+        """
+        Forward pass for the L2 penalty.
+        Args:
+            loss (torch.Tensor): The loss tensor.
+            logits (torch.Tensor): Shape [B, T, V]. The logits tensor.
+            l2_penalty_factor (float): The factor for the L2 penalty.
+        """
+        maxx, ids = torch.max(logits, dim=-1, keepdim=True)
+        ctx.logits_shape = logits.shape
+        factor = l2_penalty_factor / (logits.shape[0] * logits.shape[1])
+        maxx = maxx * factor
+        ctx.save_for_backward(maxx, ids)
         return loss
 
     @staticmethod
     def backward(ctx, grad_output):
-        logits = ctx.saved_tensors[0]
-
-        factor = 1e-4 / (logits.shape[0] * logits.shape[1])
-        maxx, ids = torch.max(logits, -1, keepdim=True)
-
-        glogits = torch.zeros_like(logits)
-        penalty_grad = maxx * factor
-        glogits.scatter_(-1, ids, penalty_grad)
-
-        return grad_output, glogits
+        maxx, ids = ctx.saved_tensors
+        glogits = torch.zeros(ctx.logits_shape, device=grad_output.device,
+                              dtype=grad_output.dtype)
+        glogits.scatter_(-1, ids, maxx)
+        return grad_output, glogits, None
 
 
 l2_warp = L2Wrap.apply
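
For context, a minimal sketch of how the memory-optimized penalty might be used in a training step. The shapes, the cross-entropy loss, and the variable names below are illustrative assumptions, not part of this commit; only l2_warp / L2Wrap come from the code above.

import torch
import torch.nn.functional as F

# Hypothetical shapes: batch B=2, sequence length T=8, vocab size V=100.
logits = torch.randn(2, 8, 100, requires_grad=True)
targets = torch.randint(0, 100, (2, 8))

loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))

# Forward is a no-op on the loss value; backward adds a gradient that pushes
# down the max logit at each position, scaled by l2_penalty_factor / (B * T).
# Only the scaled max values and their indices are saved for backward, so the
# full [B, T, V] logits tensor is not retained. The extra None returned by
# backward corresponds to l2_penalty_factor, which needs no gradient.
loss = l2_warp(loss, logits)
loss.backward()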