Commit 761e743

Max logits fix (#239)
* refactor max_logits code and add the max_logits feature for the sdpa backend and test_pipeline_sdpa
* add max_logits_print to the test_case of test_pipeline_sdpa
* enable the sdpa backend when testing max_logits in test_dist_attn
* fix unit test bugs
* refactor code
* update test_dist_attn
1 parent f46d7c5 commit 761e743
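
The diffs below thread a single new flag, `return_max_logits`, from the runtime manager's `calc_attn` method through `dist_attn_func` down to the SDPA backend, and surface the result on `AttnForwardMeta.max_logits`. A minimal usage sketch, not part of this commit: the runtime-manager setup and the q/k/v argument names are assumptions, while `calc_attn(..., return_max_logits=True)` and the `AttnForwardMeta` fields come from the diff.

import torch

def run_attn_with_max_logits(mgr, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    # `mgr` is assumed to expose the calc_attn method patched in
    # dist_attn_runtime_mgr.py below; the q/k/v names here are illustrative.
    out, meta = mgr.calc_attn(
        q=q,
        k=k,
        v=v,
        return_max_logits=True,  # new flag added in this commit
    )
    # meta is an AttnForwardMeta: meta.lse is the device-local LSE of the
    # query shard, meta.max_logits the per-query-head max logit, shape [nhq]
    assert meta.max_logits is not None
    return out, meta.max_logits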

10 files changed: +168 -27 lines changed

magi_attention/common/forward_meta.py

Lines changed: 11 additions & 0 deletions
@@ -19,5 +19,16 @@
 
 @dataclass
 class AttnForwardMeta:
+    """Attention forward metadata.
+
+    Attributes:
+        lse: Log-sum-exp of the attention weights. In a distributed setting, this is a
+            local tensor where each device holds the LSE computed from its local query
+            shards.
+        max_logits: Maximum logits per query head. In a distributed setting,
+            this is a replicated tensor where each device holds the global maximum
+            computed across the entire sequence, ensuring consistency across all devices.
+    """
+
     lse: torch.Tensor | None
     max_logits: torch.Tensor | None
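
As a hedged illustration of the docstring semantics above, and not the library's actual communication path: `lse` stays sharded per device, while a per-head local max can be made globally consistent with a MAX all-reduce, which is why `max_logits` ends up replicated across ranks.

import torch
import torch.distributed as dist

def replicate_max_logits(local_max_logits: torch.Tensor) -> torch.Tensor:
    # local_max_logits: [nhq], max over this device's local shard of the scores
    global_max = local_max_logits.clone()
    if dist.is_available() and dist.is_initialized():
        # a MAX all-reduce leaves every rank holding the same per-head maximum
        dist.all_reduce(global_max, op=dist.ReduceOp.MAX)
    return global_max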

magi_attention/dist_attn_runtime_mgr.py

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,7 @@ def calc_attn(
         sink: torch.Tensor | None = None,
         softmax_scale: float | None = None,
         softcap: float = 0.0,
+        return_max_logits: bool = False,
     ) -> tuple[torch.Tensor, AttnForwardMeta]:
         return dist_attn_func(
             q=q,
@@ -159,6 +160,7 @@ def calc_attn(
             sink=sink,
             softmax_scale=softmax_scale,
             softcap=softcap,
+            return_max_logits=return_max_logits,
         )
 
     def get_xattn_args(

magi_attention/functional/dist_attn.py

Lines changed: 30 additions & 6 deletions
@@ -230,7 +230,13 @@ def apply_fwd_partial_attn(
                 q=q,
                 sink=sink,
             )
-            return partial_out, AttnForwardMeta(lse=partial_lse, max_logits=None)
+            partial_max_logits = self._init_max_logits_skipped_host_stage(
+                q=q,
+                return_max_logits=return_max_logits,
+            )
+            return partial_out, AttnForwardMeta(
+                lse=partial_lse, max_logits=partial_max_logits
+            )
         return None, None
 
     # attention forward pass
@@ -1056,17 +1062,17 @@ def _launch_attn_fwd_kernel(
         return_max_logits: bool = False,
     ) -> tuple[torch.Tensor, AttnForwardMeta]:
         if return_max_logits:
-            assert not (
-                self.use_sdpa_backend or self.use_fa4_backend
-            ), "SDPA and FA4 backend do not support return max logits"
+            assert (
+                not self.use_fa4_backend
+            ), "FA4 backend does not support return max logits"
         with nvtx.add_nvtx_event(
             f"attn-fwd: "
             f"{attn_arg.total_area=} | "
             f"{attn_arg.q_ranges=} | "
             f"{attn_arg.k_ranges=}"
         ):
             if self.use_sdpa_backend:
-                partial_out, partial_lse = sdpa_fwd(
+                partial_out, meta = sdpa_fwd(
                     q=q,
                     k=k,
                     v=v,
@@ -1077,8 +1083,12 @@ def _launch_attn_fwd_kernel(
                     softmax_scale=softmax_scale,
                     softcap=softcap,
                     sink_layout="sh",
+                    return_max_logits=return_max_logits,
                 )
-                meta = AttnForwardMeta(lse=partial_lse, max_logits=None)
+                if return_max_logits and max_logits_acc is not None:
+                    assert meta.max_logits is not None
+                    torch.maximum(max_logits_acc, meta.max_logits, out=max_logits_acc)
+                    meta.max_logits = max_logits_acc
             elif self.use_fa4_backend:
                 partial_out, partial_lse = fa4_fwd(
                     q=q,
@@ -1920,6 +1930,20 @@ def _init_out_lse_skipped_host_stage(
 
         return out, lse
 
+    def _init_max_logits_skipped_host_stage(
+        self,
+        q: torch.Tensor,
+        return_max_logits: bool,
+    ) -> torch.Tensor | None:
+        if return_max_logits:
+            return torch.full(
+                (q.size(1),),  # [nhq]
+                fill_value=float("-inf"),
+                dtype=q.dtype,
+                device=q.device,
+            )
+        return None
+
     def _init_dq_dkv_dsink_skipped_host_stage(
         self,
         qo_do: FusedOrTupleTensor,
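
The sdpa branch above keeps a running per-head maximum: `_init_max_logits_skipped_host_stage` seeds a `[nhq]` accumulator with `-inf`, and each partial result is folded in with an in-place `torch.maximum`. A minimal standalone sketch of that pattern; the stage count and random partials here are made up for illustration.

import torch

nhq = 8
max_logits_acc = torch.full((nhq,), float("-inf"))  # accumulator, seeded with -inf
for _ in range(4):  # e.g. four partial attention stages
    partial_max_logits = torch.randn(nhq)  # stand-in for one stage's per-head max
    torch.maximum(max_logits_acc, partial_max_logits, out=max_logits_acc)
# max_logits_acc now holds the element-wise maximum over all stages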

magi_attention/functional/flex_flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ def _flex_flash_attn_forward(
     assert q.size(1) <= 128, (
         f"num_qheads ({q.size(1)}) must be <= 128 because the epilogue shmem "
         "for max_logits reduction is fixed at 128 in C++ code. You can increase "
-        "the shmem size by increasing the `smem_max_logitss` in `epilogue_fwd.hpp`."
+        "the shmem size by increasing the `smem_max_logits` in `epilogue_fwd.hpp`."
     )
 
     if ref_block_size is not None:

magi_attention/functional/sdpa.py

Lines changed: 34 additions & 10 deletions
@@ -17,6 +17,7 @@
 from einops import reduce
 
 from magi_attention.common.enum import AttnSinkLayout
+from magi_attention.common.forward_meta import AttnForwardMeta
 from magi_attention.meta.collection.calc_meta import AttnArg
 from magi_attention.utils import make_attn_mask_from_ffa_args, to_higher_fp_dtype
 
@@ -94,22 +95,30 @@ def sdpa_fwd_calc(
     v: torch.Tensor,
     attn_bias: torch.Tensor,
     softmax_scale: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
+    return_max_logits: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     attn_weight = to_higher_fp_dtype(
         q @ k.transpose(-2, -1) * softmax_scale,
         lowest_precision=torch.float32,
     )
     attn_weight += attn_bias
 
     lse = attn_weight.logsumexp(dim=-1, keepdim=True)
+    if return_max_logits:
+        # compute per-head max logits over score matrix
+        # attn_weight shape: [batch_size, num_heads, num_tokens_q, num_tokens_k]
+        bsz, nhq = attn_weight.shape[:2]
+        max_logits = attn_weight.view(bsz, nhq, -1).max(dim=-1).values.contiguous()
+    else:
+        max_logits = None
 
     # NOTE: pytorch softmax has many limitations and bugs
     # thus we use our own safe_softmax with lse involved
     attn_weight = safe_softmax(attn_weight, lse).to(v.dtype)
 
     out = attn_weight @ v
 
-    return out, lse.squeeze(-1)
+    return out, lse.squeeze(-1), max_logits
 
 
 def _sdpa_fwd(
@@ -119,14 +128,17 @@ def _sdpa_fwd(
     attn_mask: torch.Tensor | None = None,
     is_causal: bool = False,
     softmax_scale: float | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
+    return_max_logits: bool = False,
+) -> tuple[torch.Tensor, AttnForwardMeta]:
     q, k, v, attn_bias, softmax_scale, _ = sdpa_fwd_preprocess(
         q, k, v, attn_mask, is_causal, softmax_scale
     )
 
-    out, lse = sdpa_fwd_calc(q, k, v, attn_bias, softmax_scale)
+    out, lse, max_logits = sdpa_fwd_calc(
+        q, k, v, attn_bias, softmax_scale, return_max_logits
+    )
 
-    return out, lse
+    return out, AttnForwardMeta(lse=lse, max_logits=max_logits)
 
 
 @torch.no_grad()
@@ -139,7 +151,8 @@ def sdpa_fwd(
     softmax_scale: float | None = None,
     softcap: float = 0.0,
     sink_layout: AttnSinkLayout = "sh",
-) -> tuple[torch.Tensor, torch.Tensor]:
+    return_max_logits: bool = False,
+) -> tuple[torch.Tensor, AttnForwardMeta]:
     """SDPA forward function
 
     Args:
@@ -163,12 +176,19 @@ def sdpa_fwd(
 
         sink_layout (AttnSinkLayout, optional): sink layout. Defaults to "sh".
 
+        return_max_logits (bool, optional): whether to return max logits.
+            Defaults to ``False``.
+
     Returns:
         torch.Tensor: out with shape [num_tokens_q, num_heads_q, head_dim]
            or [batch_size, num_heads_q, num_tokens_q, head_dim]
 
-        torch.Tensor: lse with shape [num_tokens_q, num_heads_q]
-            or [batch_size, num_heads_q, num_tokens_q]
+        AttnForwardMeta: metadata for attention forward, including lse and max_logits.
+            - lse (torch.Tensor): [num_tokens_q, num_heads_q]
+                or [batch_size, num_heads_q, num_tokens_q]
+            - max_logits (torch.Tensor or None): [num_heads_q]
+                or [batch_size, num_heads_q]
+                or None if return_max_logits is False
     """
     assert softcap == 0.0, "non-zero softcap is not supported by now"
 
@@ -187,17 +207,21 @@ def sdpa_fwd(
         device=torch.cuda.current_device(),
     )
 
-    out, lse = _sdpa_fwd(
+    out, meta = _sdpa_fwd(
         q,
         k,
         v,
         attn_mask=attn_mask,
         is_causal=False,
         softmax_scale=softmax_scale,
+        return_max_logits=return_max_logits,
     )
+    lse, max_logits = meta.lse, meta.max_logits
 
     if rearrange:
         out, lse = sdpa_fwd_out_lse_rearrange(out, lse)
+        if max_logits is not None:
+            max_logits = max_logits.squeeze(0)
 
     if sink is not None:
         assert rearrange
@@ -209,7 +233,7 @@ def sdpa_fwd(
         inplace=True,
     )
 
-    return out, lse
+    return out, AttnForwardMeta(lse=lse, max_logits=max_logits)
 
 
 # ------------------ sdpa bwd ------------------ #
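
For the per-head reduction added to `sdpa_fwd_calc`, flattening the `[bsz, nhq, sq, sk]` score matrix per head and taking the max is equivalent to an `amax` over the last two dims. A quick sanity sketch with random tensors, not part of the commit:

import torch

bsz, nhq, sq, sk = 2, 4, 16, 16
attn_weight = torch.randn(bsz, nhq, sq, sk)  # stand-in for the score matrix

max_logits = attn_weight.view(bsz, nhq, -1).max(dim=-1).values  # as in the diff
assert torch.equal(max_logits, attn_weight.amax(dim=(-2, -1)))  # shape [bsz, nhq]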

magi_attention/testing/ref_attn.py

Lines changed: 2 additions & 1 deletion
@@ -700,6 +700,7 @@ def ref_attn_func(
     # maybe cast input to high precision
     org_dtype = q.dtype
     lse_dtype = max_fp_dtype(org_dtype, torch.float32)
+    max_logits_dtype = max_fp_dtype(org_dtype, torch.float32)
     if high_precision:  # use fp64 as ground-truth
         q = q.to(torch.float64)
         k = k.to(torch.float64)
@@ -743,6 +744,6 @@ def ref_attn_func(
     if return_max_logits:
         assert meta is not None  # mypy
         assert meta.max_logits is not None  # mypy
-        meta.max_logits = meta.max_logits.to(torch.float32)
+        meta.max_logits = meta.max_logits.to(max_logits_dtype)
 
     return out, meta
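
The reference path now casts `max_logits` to `max_fp_dtype(org_dtype, torch.float32)` instead of hard-coding fp32, mirroring the existing `lse_dtype` handling, so fp64 inputs presumably keep fp64 precision rather than being forced down to fp32. A hedged local imitation of that policy; the helper below is not the library's `max_fp_dtype`.

import torch

def _at_least_fp32(dtype: torch.dtype) -> torch.dtype:
    # keep the wider of `dtype` and fp32, judged by bit width
    return dtype if torch.finfo(dtype).bits > 32 else torch.float32

assert _at_least_fp32(torch.bfloat16) is torch.float32
assert _at_least_fp32(torch.float64) is torch.float64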

tests/test_api/test_interface.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def world_size(self) -> int:
 
     @property
     def timeout(self) -> int:
-        return 600
+        return 1200
 
     @property
     def seed(self) -> int:

tests/test_attn/test_dist_attn.py

Lines changed: 3 additions & 7 deletions
@@ -118,7 +118,7 @@ def world_size(self) -> int:
 
     @property
     def timeout(self) -> int:
-        return 1200
+        return 1800
 
     @property
     def seed(self) -> int:
@@ -163,10 +163,6 @@ def test_full_attn(
         if use_native_grpcoll:
             return
 
-        # sdpa backend do not support return max logits
-        if return_max_logits and use_sdpa_backend:
-            return
-
         # switch the env flags
         switch_back = switch_envvars(
             envvar_name_list=[
@@ -341,7 +337,7 @@ def test_full_attn(
             local_max_logits,
             total_max_logits_ref,
             atol=EPSILON,
-            rtol=1e-3,
+            rtol=1e-2 if use_sdpa_backend else 1e-3,
             mismatch_threshold=0.01,
             test_case="max_logits",
         )
@@ -373,7 +369,7 @@ def test_full_attn(
         assert_close(
             total_dsink,
             total_dsink_ref,
-            atol=1e-3,
+            atol=5e-3,
             rtol=0.1,
             mismatch_threshold=max(1 / (seqlen_sink * nhq), 5e-2),
             test_case="dsink",
