@@ -36,13 +36,13 @@ class SparsemaxFunction(Function):
3636 """
3737
3838 @staticmethod
39- def forward (ctx , input , dim = - 1 ):
39+ def forward (ctx , input_ , dim = - 1 ):
4040 """
4141 Forward pass of sparsemax: a normalizing, sparse transformation.
4242
4343 Parameters
4444 ----------
45- input : torch.Tensor
45+ input_ : torch.Tensor
4646 The input tensor on which sparsemax will be applied.
4747 dim : int, optional
4848 Dimension along which to apply sparsemax. Default is -1.
@@ -53,10 +53,10 @@ def forward(ctx, input, dim=-1):
5353 A tensor with the same shape as the input, with sparsemax applied.
5454 """
5555 ctx .dim = dim
56- max_val , _ = input .max (dim = dim , keepdim = True )
57- input -= max_val # Numerical stability trick, as with softmax.
58- tau , supp_size = SparsemaxFunction ._threshold_and_support (input , dim = dim )
59- output = torch .clamp (input - tau , min = 0 )
56+ max_val , _ = input_ .max (dim = dim , keepdim = True )
57+ input_ -= max_val # Numerical stability trick, as with softmax.
58+ tau , supp_size = SparsemaxFunction ._threshold_and_support (input_ , dim = dim )
59+ output = torch .clamp (input_ - tau , min = 0 )
6060 ctx .save_for_backward (supp_size , output )
6161 return output
6262
@@ -86,13 +86,13 @@ def backward(ctx, grad_output): # type: ignore
8686 return grad_input , None
8787
8888 @staticmethod
89- def _threshold_and_support (input , dim = - 1 ):
89+ def _threshold_and_support (input_ , dim = - 1 ):
9090 """
9191 Computes the threshold and support for sparsemax.
9292
9393 Parameters
9494 ----------
95- input : torch.Tensor
95+ input_ : torch.Tensor
9696 The input tensor on which to compute the threshold and support.
9797 dim : int, optional
9898 Dimension along which to compute the threshold and support. Default is -1.
@@ -103,14 +103,14 @@ def _threshold_and_support(input, dim=-1):
103103 - torch.Tensor : The threshold value for sparsemax.
104104 - torch.Tensor : The support size tensor.
105105 """
106- input_srt , _ = torch .sort (input , descending = True , dim = dim )
106+ input_srt , _ = torch .sort (input_ , descending = True , dim = dim )
107107 input_cumsum = input_srt .cumsum (dim ) - 1
108- rhos = _make_ix_like (input , dim )
108+ rhos = _make_ix_like (input_ , dim )
109109 support = rhos * input_srt > input_cumsum
110110
111111 support_size = support .sum (dim = dim ).unsqueeze (dim )
112112 tau = input_cumsum .gather (dim , support_size - 1 )
113- tau /= support_size .to (input .dtype )
113+ tau /= support_size .to (input_ .dtype )
114114 return tau , support_size
115115
116116
0 commit comments