
Commit 10bbc83

Fixes for paged fp8 attention with chunked prefill
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent: 65acdec · commit: 10bbc83

File tree: 1 file changed

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 72 additions & 58 deletions
@@ -30,62 +30,62 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-if Version(torch.__version__) <= Version("2.7"):
-    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
-    # while for earlier versions we need a custom definition
-    def _scaled_mm_cpu_out(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-        *,
-        out: Optional[Tensor] = None,
-    ) -> Tensor:
-        if out_dtype is None:
-            out_dtype = torch.float32
-        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-        if bias is not None:
-            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-        else:
-            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-        if out is not None:
-            out.copy_(ret)
-            return out
-        return ret
-
-    torch.library.register_kernel(
-        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
-    )
+# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set.
+# This CPU implementation is not enough for our use case, so we still have to
+# keep our own custom version.
+def _scaled_mm_cpu_out(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
 
-    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
-    def _scaled_mm_cpu(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-    ) -> Tensor:
-        return _scaled_mm_cpu_out(
-            mat1,
-            mat2,
-            scale1,
-            scale2,
-            bias,
-            scale_result,
-            out_dtype,
-            use_fast_accum,
-            out=None,
-        )
+    if bias is not None:
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
+torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
+
+
+@torch.library.register_kernel("aten::_scaled_mm", "cpu")
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )
 
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
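
Note (not part of the commit): a minimal sketch of how the CPU fallback registered in the hunk above can be exercised, assuming fms_mo.aiu_addons.fp8.fp8_spyre_op imports cleanly on your PyTorch build and the register_kernel calls succeed; the shapes and scale values below are illustrative only.

import torch

# Importing the module registers the custom CPU kernels shown above
# (assumed import path; adjust to however fms_mo is installed).
import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: F401

# Quantize two small matrices to fp8 and multiply them on CPU via aten::_scaled_mm,
# which should dispatch to the custom _scaled_mm_cpu / _scaled_mm_cpu_out emulation.
a = torch.randn(4, 8).to(torch.float8_e4m3fn)
b = torch.randn(8, 16).to(torch.float8_e4m3fn)
scale_a = torch.tensor(0.5)  # illustrative per-tensor scales
scale_b = torch.tensor(2.0)

out = torch.ops.aten._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([4, 16]) torch.bfloat16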
@@ -218,6 +218,7 @@ def scaled_paged_attn_compute(
     num_kv_heads = value_cache.shape[2]
     head_size = value_cache.shape[3]
     block_size = value_cache.shape[1]
+    seq_len_q = query.shape[1]
     num_seqs = query.shape[0]
 
     block_tables_lst = block_table.cpu().tolist()
@@ -228,6 +229,7 @@ def scaled_paged_attn_compute(
         block_table = block_tables_lst[i]
         start_pos = int(left_padded_prompt_mask[i].item())
         seq_len = int(seq_lens_lst[i])
+        seq_len_q_i = seq_len_q
 
         keys_lst: list[torch.Tensor] = []
         values_lst: list[torch.Tensor] = []
@@ -243,13 +245,25 @@ def scaled_paged_attn_compute(
             values_lst.append(v)
         keys = torch.stack(keys_lst, dim=0)
         values = torch.stack(values_lst, dim=0)
+        seq_len_kv = keys.shape[0]
+
+        # cut the pads for first prefill
+        if q.shape[0] > seq_len_kv:
+            seq_len_q_i = seq_len_kv
+            q = q[-seq_len_kv:]
+
         if num_kv_heads > 1:
             # Handle MQA and GQA
             keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1)
             values = torch.repeat_interleave(
                 values, num_query_heads // num_kv_heads, dim=1
             )
 
+        # Generate mask for prefix attention
+        mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
+        mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
+        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+
         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
             (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
@@ -258,12 +272,12 @@ def scaled_paged_attn_compute(
             (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
-            is_causal=False,  # decode assumes no causal mask
+            attn_mask=mask,  # decode assumes no causal mask
             scale=scale,
         )
 
-        out = out.view(num_query_heads, head_size)
-        output[i].copy_(out, non_blocking=True)
+        out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size)
+        output[i][-seq_len_q_i:] = out
     return output
 
 
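Note (not part of the commit): a standalone sketch of the prefix-attention mask the new code builds for chunked prefill, where queries of the current chunk see every previously cached KV position and attend causally within the chunk itself. The helper name and shapes are illustrative.

import torch


def chunked_prefill_mask(seq_len_q: int, seq_len_kv: int) -> torch.Tensor:
    """Additive SDPA mask: full visibility over cached KV, causal over the new chunk."""
    mask = torch.ones((1, 1, seq_len_q, seq_len_kv), dtype=torch.bool)
    # The last seq_len_q key positions belong to the current query chunk:
    # restrict them to a lower-triangular (causal) pattern.
    mask[:, :, :, -seq_len_q:] = torch.tril(mask[:, :, :, -seq_len_q:])
    # Convert boolean visibility into the additive form scaled_dot_product_attention expects.
    return torch.where(mask.logical_not(), -torch.inf, 0.0)


# 3 new query tokens attending over 5 cached + current KV positions:
print(chunked_prefill_mask(3, 5))
# tensor([[[[0., 0., 0., -inf, -inf],
#           [0., 0., 0., 0., -inf],
#           [0., 0., 0., 0., 0.]]]])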