Commit 3a373ff

Add choice for scaled bmm
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent bdf1cf2 commit 3a373ff

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 25 additions & 16 deletions
@@ -42,6 +42,7 @@ class MathFP8AttentionKwargs(AttentionKwargs):
 
     mask: NotRequired[torch.Tensor]
     do_scale_q: bool
+    do_scaled_bmm: bool
     is_causal_mask: bool
 
 # TODO: Figure out better scales for AIU? These come from vLLM
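The new `do_scaled_bmm` key is read from the attention kwargs at call time (see the next hunk). Below is a minimal call-site sketch, assuming the kwargs are passed as a plain dict matching this TypedDict; only the keys visible in this diff are shown, and the mask shape is a placeholder.

```python
import torch

# Hypothetical call-site sketch: only the keys visible in this hunk are shown,
# and the mask shape/dtype is a placeholder. Other fields of
# MathFP8AttentionKwargs may exist outside this diff.
attn_kwargs = {
    "mask": torch.zeros(1, 1, 128, 128),  # NotRequired: optional attention mask
    "do_scale_q": True,       # rescale Q by max-abs before the FP8 cast (next hunk)
    "do_scaled_bmm": True,    # new flag: route QK^T through the scaled BMM op
    "is_causal_mask": False,
}
# With "do_scaled_bmm": False, the compute op now falls back to a plain
# matmul against the dequantized key cache (see the last hunk).
```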
@@ -110,14 +111,17 @@ def _math_fp8_compute_op(
     the custom scaled BMM op that was pre-registered for torch.compile."""
 
     orig_dtype = query.dtype
+    do_scaled_bmm = attn_kwargs.get("do_scaled_bmm", False)
 
-    # Scaling the Q tensor is optional
-    q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
-    if attn_kwargs.get("do_scale_q", False):
-        q_scale.copy_(torch.abs(query).max() / Q_RANGE)
-        query = query / q_scale
+    if do_scaled_bmm:
+        # Scaling the Q tensor is optional
+        q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
+        if attn_kwargs.get("do_scale_q", False):
+            q_scale.copy_(torch.abs(query).max() / Q_RANGE)
+            query = query / q_scale
 
-    query = query.to(torch.float8_e4m3fn).transpose(2, 1)
+        query = query.to(torch.float8_e4m3fn)
+    query = query.transpose(2, 1)
 
     # Grab kv cache and deal with cases where no store op was run
     if isinstance(key_cache, ScaledTensor) and isinstance(
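The guarded block above computes a per-tensor max-abs scale so the query fits the FP8 range before the cast; the Q scaling and the `float8_e4m3fn` cast now happen only on the scaled-BMM path, while the transpose applies to both paths. A standalone sketch of that scaling step follows; `Q_RANGE` here is a stand-in constant (the real value is defined in fp8_attn.py and, per the TODO, comes from vLLM), with 448.0 used only because it is the float8_e4m3fn maximum.

```python
import torch

# Illustrative sketch of the optional Q scaling above. Q_RANGE is an assumed
# stand-in; the real constant lives in fms_mo/aiu_addons/fp8/fp8_attn.py.
Q_RANGE = 448.0

query = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)

# Per-tensor max-abs scale: after dividing by q_scale, |query| <= Q_RANGE,
# so the subsequent cast to float8_e4m3fn does not saturate.
q_scale = torch.abs(query).max().float() / Q_RANGE
query_fp8 = (query / q_scale).to(torch.float8_e4m3fn)

# q_scale is carried alongside the FP8 tensor and later handed to the scaled
# BMM op, which folds it back into the output.
print(q_scale, query_fp8.dtype)
```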
@@ -175,17 +179,22 @@ def _math_fp8_compute_op(
             query.size(-3) // value_cache.size(-3), -3
         )
 
-    attn_weight = (
-        torch.ops.spyre.scaled_bmm(
-            query,
-            key_cache.transpose(-2, -1),
-            q_scale,
-            k_scale,
-            out_dtype=orig_dtype,
-            use_fast_accum=True,
+    if do_scaled_bmm:
+        attn_weight = (
+            torch.ops.spyre.scaled_bmm(
+                query,
+                key_cache.transpose(-2, -1),
+                q_scale,
+                k_scale,
+                out_dtype=orig_dtype,
+                use_fast_accum=True,
+            )
+            * scale_factor
         )
-        * scale_factor
-    )
+    else:
+        key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1)
+        attn_weight = query @ key_t
+        attn_weight *= scale_factor
     attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
     attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
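To make the two branches concrete, here is a hedged sketch of the score computation. It assumes `torch.ops.spyre.scaled_bmm(a, b, a_scale, b_scale, out_dtype=d, ...)` behaves like dequantizing both operands by their scales and multiplying them in `d`; that emulation and the helper name `attn_scores` are illustrative only, not part of the source. The fallback keeps the query in its original dtype, which is why the Q scaling and FP8 cast moved inside `if do_scaled_bmm:` in the second hunk.

```python
import torch

def attn_scores(query, key_cache_fp8, q_scale, k_scale, scale_factor,
                do_scaled_bmm=False):
    """Illustrative emulation of the branch added in this commit (not the spyre op)."""
    orig_dtype = query.dtype
    if do_scaled_bmm:
        # On AIU this call is torch.ops.spyre.scaled_bmm(query_fp8, key_t,
        # q_scale, k_scale, out_dtype=orig_dtype, use_fast_accum=True);
        # here both operands are dequantized by their scales and multiplied.
        query_fp8 = (query / q_scale).to(torch.float8_e4m3fn)
        key_t = key_cache_fp8.transpose(-2, -1)
        attn_weight = (
            (query_fp8.to(orig_dtype) * q_scale)
            @ (key_t.to(orig_dtype) * k_scale)
        ) * scale_factor
    else:
        # Fallback added in this commit: dequantize K, plain matmul in orig dtype.
        key_t = (key_cache_fp8.to(dtype=orig_dtype) * k_scale).transpose(-2, -1)
        attn_weight = query @ key_t
        attn_weight *= scale_factor
    return attn_weight

# Example usage with placeholder shapes and unit scales.
q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)
k_fp8 = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
one = torch.tensor(1.0)
w_ref = attn_scores(q, k_fp8, one, one, scale_factor=64 ** -0.5)                      # fallback path
w_fp8 = attn_scores(q, k_fp8, one, one, scale_factor=64 ** -0.5, do_scaled_bmm=True)  # emulated scaled path
```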

0 commit comments

Comments
 (0)