Commit e3f1310

Merge pull request #191 from foundation-model-stack/chunked_fp8

fix: Fixes for paged fp8 attention with chunked prefill

2 parents b3437b7 + 1991d09
File tree: 1 file changed (+79, −57 lines)
fms_mo/aiu_addons/fp8/fp8_spyre_op.py (79 additions, 57 deletions)
@@ -17,7 +17,6 @@
 from typing import Optional
 
 # Third Party
-from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
@@ -30,62 +29,71 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-if Version(torch.__version__) <= Version("2.7"):
-    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
-    # while for earlier versions we need a custom definition
-    def _scaled_mm_cpu_out(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-        *,
-        out: Optional[Tensor] = None,
-    ) -> Tensor:
-        if out_dtype is None:
-            out_dtype = torch.float32
-        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-        if bias is not None:
-            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-        else:
-            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-        if out is not None:
-            out.copy_(ret)
-            return out
-        return ret
+# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set.
+# This CPU implementation is not enough for our use case, so we still have to
+# keep our own custom version.
+def _scaled_mm_cpu_out(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+    if bias is not None:
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
 
+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )
+
+
+if torch.__version__ >= "2.8":
+    DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+    torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out
+    torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu
+else:
     torch.library.register_kernel(
         torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
     )
-
-    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
-    def _scaled_mm_cpu(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-    ) -> Tensor:
-        return _scaled_mm_cpu_out(
-            mat1,
-            mat2,
-            scale1,
-            scale2,
-            bias,
-            scale_result,
-            out_dtype,
-            use_fast_accum,
-            out=None,
-        )
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.default, "cpu", _scaled_mm_cpu
+    )
 
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
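For reference, a minimal sketch (not part of the patch) of the math the CPU fallback performs: upcast both fp8 operands, apply their per-tensor scales, then run a plain matmul. The shapes, scale values, and the choice of torch.float8_e4m3fn below are illustrative assumptions, not values taken from the repository.

# Illustrative only: mirrors the dequantize-then-mm path of _scaled_mm_cpu.
import torch

mat1 = torch.randn(4, 8).to(torch.float8_e4m3fn)   # fp8 activations (made-up shape)
mat2 = torch.randn(8, 16).to(torch.float8_e4m3fn)  # fp8 weights (made-up shape)
scale1 = torch.tensor(0.5)                         # per-tensor scales (made-up values)
scale2 = torch.tensor(0.25)

out_dtype = torch.float32
a = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
b = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
ret = torch.mm(a, b)   # what _scaled_mm_cpu returns when bias is None
print(ret.shape)       # torch.Size([4, 16])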
@@ -115,7 +123,7 @@ def spyre_scaled_bmm(
         device=mat1.device,
     )
     for b_idx in range(mat1.shape[0]):
-        out[b_idx] = torch._scaled_mm(
+        out[b_idx] = _scaled_mm_cpu_out(
             mat1[b_idx],
             mat2[b_idx],
             scale1,
@@ -218,6 +226,7 @@ def scaled_paged_attn_compute(
     num_kv_heads = value_cache.shape[2]
     head_size = value_cache.shape[3]
     block_size = value_cache.shape[1]
+    seq_len_q = query.shape[1]
     num_seqs = query.shape[0]
 
     block_tables_lst = block_table.cpu().tolist()
@@ -228,6 +237,7 @@ def scaled_paged_attn_compute(
         block_table = block_tables_lst[i]
         start_pos = int(left_padded_prompt_mask[i].item())
         seq_len = int(seq_lens_lst[i])
+        seq_len_q_i = seq_len_q
 
         keys_lst: list[torch.Tensor] = []
         values_lst: list[torch.Tensor] = []
@@ -243,13 +253,25 @@ def scaled_paged_attn_compute(
             values_lst.append(v)
         keys = torch.stack(keys_lst, dim=0)
         values = torch.stack(values_lst, dim=0)
+        seq_len_kv = keys.shape[0]
+
+        # cut the pads for first prefill
+        if q.shape[0] > seq_len_kv:
+            seq_len_q_i = seq_len_kv
+            q = q[-seq_len_kv:]
+
         if num_kv_heads > 1:
             # Handle MQA and GQA
             keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1)
             values = torch.repeat_interleave(
                 values, num_query_heads // num_kv_heads, dim=1
             )
 
+        # Generate mask for prefix attention
+        mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
+        mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
+        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+
         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
             (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
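To see what the prefix-attention mask added above looks like, here is a small worked example; the sizes are made up and illustrate one chunked-prefill step where seq_len_q_i = 3 new query tokens attend over seq_len_kv = 5 total positions (2 already cached, 3 new). Cached columns stay fully visible, the new ones are causal.

import torch

seq_len_q_i, seq_len_kv = 3, 5  # hypothetical sizes for one chunked-prefill step
mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
print(mask[0, 0])
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])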
@@ -258,12 +280,12 @@ def scaled_paged_attn_compute(
             (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
-            is_causal=False,  # decode assumes no causal mask
+            attn_mask=mask,  # decode assumes no causal mask
             scale=scale,
         )
 
-        out = out.view(num_query_heads, head_size)
-        output[i].copy_(out, non_blocking=True)
+        out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size)
+        output[i][-seq_len_q_i:] = out
     return output
 
 
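The rewritten write-back at the end of scaled_paged_attn_compute only overwrites the last seq_len_q_i rows of each sequence's output slot, leaving left-padded positions untouched. A toy sketch with made-up sizes:

import torch

seq_len_q, seq_len_q_i, num_query_heads, head_size = 6, 4, 2, 3  # made-up sizes
output_i = torch.zeros(seq_len_q, num_query_heads, head_size)    # stands in for output[i]
out = torch.randn(seq_len_q_i, num_query_heads, head_size)       # attention result for the real tokens
output_i[-seq_len_q_i:] = out                                    # right-aligned write-back
assert torch.equal(output_i[:2], torch.zeros(2, num_query_heads, head_size))  # padded rows untouched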