 from einops import reduce

 from magi_attention.common.enum import AttnSinkLayout
+from magi_attention.common.forward_meta import AttnForwardMeta
 from magi_attention.meta.collection.calc_meta import AttnArg
 from magi_attention.utils import make_attn_mask_from_ffa_args, to_higher_fp_dtype

@@ -94,22 +95,30 @@ def sdpa_fwd_calc(
     v: torch.Tensor,
     attn_bias: torch.Tensor,
     softmax_scale: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
+    return_max_logits: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     attn_weight = to_higher_fp_dtype(
         q @ k.transpose(-2, -1) * softmax_scale,
         lowest_precision=torch.float32,
     )
     attn_weight += attn_bias

     lse = attn_weight.logsumexp(dim=-1, keepdim=True)
+    if return_max_logits:
+        # compute per-head max logits over score matrix
+        # attn_weight shape: [batch_size, num_heads, num_tokens_q, num_tokens_k]
+        bsz, nhq = attn_weight.shape[:2]
+        max_logits = attn_weight.view(bsz, nhq, -1).max(dim=-1).values.contiguous()
+    else:
+        max_logits = None

     # NOTE: pytorch softmax has many limitations and bugs
     # thus we use our own safe_softmax with lse involved
     attn_weight = safe_softmax(attn_weight, lse).to(v.dtype)

     out = attn_weight @ v

-    return out, lse.squeeze(-1)
+    return out, lse.squeeze(-1), max_logits


 def _sdpa_fwd(
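For reference, the `return_max_logits` reduction added above (view + max over the flattened token dimensions) is simply a per-head maximum over the full score matrix. A minimal standalone sketch with a toy tensor, assuming the same `[batch_size, num_heads, num_tokens_q, num_tokens_k]` layout:

```python
import torch

# toy scores with the same layout as attn_weight:
# [batch_size, num_heads, num_tokens_q, num_tokens_k]
scores = torch.randn(2, 4, 8, 8)

bsz, nhq = scores.shape[:2]
max_logits_view = scores.view(bsz, nhq, -1).max(dim=-1).values  # as in the diff
max_logits_amax = scores.amax(dim=(-2, -1))                     # equivalent reduction

assert torch.equal(max_logits_view, max_logits_amax)  # both are [batch_size, num_heads]
```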
@@ -119,14 +128,17 @@ def _sdpa_fwd(
     attn_mask: torch.Tensor | None = None,
     is_causal: bool = False,
     softmax_scale: float | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+    return_max_logits: bool = False,
+) -> tuple[torch.Tensor, AttnForwardMeta]:
     q, k, v, attn_bias, softmax_scale, _ = sdpa_fwd_preprocess(
         q, k, v, attn_mask, is_causal, softmax_scale
     )

-    out, lse = sdpa_fwd_calc(q, k, v, attn_bias, softmax_scale)
+    out, lse, max_logits = sdpa_fwd_calc(
+        q, k, v, attn_bias, softmax_scale, return_max_logits
+    )

-    return out, lse
+    return out, AttnForwardMeta(lse=lse, max_logits=max_logits)


 @torch.no_grad()
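`AttnForwardMeta` is imported from `magi_attention.common.forward_meta` but not defined in this diff. Judging only from how it is constructed and read here (`AttnForwardMeta(lse=..., max_logits=...)`, `meta.lse`, `meta.max_logits`), a minimal sketch of such a container might look like the following; this is an assumption, not the actual definition:

```python
from dataclasses import dataclass

import torch


# hypothetical sketch of magi_attention.common.forward_meta.AttnForwardMeta,
# inferred only from its construction and attribute access in this file
@dataclass
class AttnForwardMeta:
    lse: torch.Tensor
    max_logits: torch.Tensor | None = None
```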
@@ -139,7 +151,8 @@ def sdpa_fwd(
     softmax_scale: float | None = None,
     softcap: float = 0.0,
     sink_layout: AttnSinkLayout = "sh",
-) -> tuple[torch.Tensor, torch.Tensor]:
+    return_max_logits: bool = False,
+) -> tuple[torch.Tensor, AttnForwardMeta]:
     """SDPA forward function

     Args:
@@ -163,12 +176,19 @@ def sdpa_fwd(

         sink_layout (AttnSinkLayout, optional): sink layout. Defaults to "sh".

+        return_max_logits (bool, optional): whether to return max logits.
+            Defaults to ``False``.
+
     Returns:
         torch.Tensor: out with shape [num_tokens_q, num_heads_q, head_dim]
             or [batch_size, num_heads_q, num_tokens_q, head_dim]

-        torch.Tensor: lse with shape [num_tokens_q, num_heads_q]
-            or [batch_size, num_heads_q, num_tokens_q]
+        AttnForwardMeta: metadata for the attention forward pass, including lse and max_logits.
+            - lse (torch.Tensor): [num_tokens_q, num_heads_q]
+                or [batch_size, num_heads_q, num_tokens_q]
+            - max_logits (torch.Tensor or None): [num_heads_q]
+                or [batch_size, num_heads_q],
+                or None if return_max_logits is False
     """
     assert softcap == 0.0, "non-zero softcap is not supported by now"

@@ -187,17 +207,21 @@ def sdpa_fwd(
         device=torch.cuda.current_device(),
     )

-    out, lse = _sdpa_fwd(
+    out, meta = _sdpa_fwd(
         q,
         k,
         v,
         attn_mask=attn_mask,
         is_causal=False,
         softmax_scale=softmax_scale,
+        return_max_logits=return_max_logits,
     )
+    lse, max_logits = meta.lse, meta.max_logits

     if rearrange:
         out, lse = sdpa_fwd_out_lse_rearrange(out, lse)
+        if max_logits is not None:
+            max_logits = max_logits.squeeze(0)

     if sink is not None:
         assert rearrange
@@ -209,7 +233,7 @@ def sdpa_fwd(
             inplace=True,
         )

-    return out, lse
+    return out, AttnForwardMeta(lse=lse, max_logits=max_logits)


 # ------------------ sdpa bwd ------------------ #
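Under the new signature of `sdpa_fwd` above, a caller-side sketch with `return_max_logits=True`. This is a usage assumption rather than code from the commit; arguments other than the tensors and the new flag are elided, and the commented shapes follow the docstring:

```python
# hypothetical usage of the updated sdpa_fwd return type
out, meta = sdpa_fwd(
    q,
    k,
    v,
    # ... remaining arguments elided; see the signature above ...
    return_max_logits=True,
)

lse = meta.lse                # [num_tokens_q, num_heads_q] or [batch_size, num_heads_q, num_tokens_q]
max_logits = meta.max_logits  # [num_heads_q] or [batch_size, num_heads_q]; None if return_max_logits=False
```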