 from typing import Optional

 # Third Party
-from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


-if Version(torch.__version__) <= Version("2.7"):
-    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
-    # while for earlier versions we need a custom definition
-    def _scaled_mm_cpu_out(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-        *,
-        out: Optional[Tensor] = None,
-    ) -> Tensor:
-        if out_dtype is None:
-            out_dtype = torch.float32
-        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-        if bias is not None:
-            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-        else:
-            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-        if out is not None:
-            out.copy_(ret)
-            return out
-        return ret
+# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set.
+# This CPU implementation is not enough for our use case, so we still have to
+# keep our own custom version.
+def _scaled_mm_cpu_out(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+    if bias is not None:
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)

+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )
+
+
+if torch.__version__ >= "2.8":
+    DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+    torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out
+    torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu
+else:
     torch.library.register_kernel(
         torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
     )
-
-    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
-    def _scaled_mm_cpu(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-    ) -> Tensor:
-        return _scaled_mm_cpu_out(
-            mat1,
-            mat2,
-            scale1,
-            scale2,
-            bias,
-            scale_result,
-            out_dtype,
-            use_fast_accum,
-            out=None,
-        )
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.default, "cpu", _scaled_mm_cpu
+    )
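For reference, a minimal sketch of what the CPU fallback computes (not part of this change); the shapes, scale values, and the availability of float8_e4m3fn casts on CPU are assumptions, and the snippet assumes the definitions above are in scope:

# Illustration only: exercising the module-level CPU helper defined above.
a = torch.randn(16, 32).to(torch.float8_e4m3fn)  # fp8 operand (assumed castable on CPU)
b = torch.randn(32, 8).to(torch.float8_e4m3fn)
scale_a = torch.tensor(0.5)   # per-tensor dequantization scales
scale_b = torch.tensor(0.25)

out = _scaled_mm_cpu(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
# Equivalent to ((a.to(bfloat16) * 0.5) @ (b.to(bfloat16) * 0.25)).to(bfloat16)
print(out.shape, out.dtype)  # torch.Size([16, 8]) torch.bfloat16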


 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
@@ -115,7 +123,7 @@ def spyre_scaled_bmm(
         device=mat1.device,
     )
     for b_idx in range(mat1.shape[0]):
-        out[b_idx] = torch._scaled_mm(
+        out[b_idx] = _scaled_mm_cpu_out(
             mat1[b_idx],
             mat2[b_idx],
             scale1,
@@ -218,6 +226,7 @@ def scaled_paged_attn_compute(
     num_kv_heads = value_cache.shape[2]
     head_size = value_cache.shape[3]
     block_size = value_cache.shape[1]
+    seq_len_q = query.shape[1]
     num_seqs = query.shape[0]

     block_tables_lst = block_table.cpu().tolist()
@@ -228,6 +237,7 @@ def scaled_paged_attn_compute(
         block_table = block_tables_lst[i]
         start_pos = int(left_padded_prompt_mask[i].item())
         seq_len = int(seq_lens_lst[i])
+        seq_len_q_i = seq_len_q

         keys_lst: list[torch.Tensor] = []
         values_lst: list[torch.Tensor] = []
@@ -243,13 +253,25 @@ def scaled_paged_attn_compute(
             values_lst.append(v)
         keys = torch.stack(keys_lst, dim=0)
         values = torch.stack(values_lst, dim=0)
+        seq_len_kv = keys.shape[0]
+
+        # cut the pads for first prefill
+        if q.shape[0] > seq_len_kv:
+            seq_len_q_i = seq_len_kv
+            q = q[-seq_len_kv:]
+
         if num_kv_heads > 1:
             # Handle MQA and GQA
             keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1)
             values = torch.repeat_interleave(
                 values, num_query_heads // num_kv_heads, dim=1
             )

+        # Generate mask for prefix attention
+        mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
+        mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
+        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+
         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
             (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
@@ -258,12 +280,12 @@ def scaled_paged_attn_compute(
             (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
-            is_causal=False,  # decode assumes no causal mask
+            attn_mask=mask,  # prefix-causal mask built above
             scale=scale,
         )

-        out = out.view(num_query_heads, head_size)
-        output[i].copy_(out, non_blocking=True)
+        out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size)
+        output[i][-seq_len_q_i:] = out
     return output


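To make the prefix-attention mask above concrete, here is a small worked sketch (sizes are arbitrary assumptions): with two query tokens attending over four gathered key/value positions, the left-pad cut keeps seq_len_q_i in step with the gathered keys, and the mask is causal only over the trailing query positions.

# Illustration only: the mask built in scaled_paged_attn_compute for
# seq_len_q_i = 2 query tokens over seq_len_kv = 4 cached positions.
import torch

seq_len_q_i, seq_len_kv = 2, 4
mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
print(mask[0, 0])
# tensor([[0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# The first query token attends to the prefix and itself but not to the later
# token; the second attends to everything.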