 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-if Version(torch.__version__) <= Version("2.7"):
-    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
-    # while for earlier versions we need a custom definition
-    def _scaled_mm_cpu_out(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-        *,
-        out: Optional[Tensor] = None,
-    ) -> Tensor:
-        if out_dtype is None:
-            out_dtype = torch.float32
-        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-        if bias is not None:
-            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-        else:
-            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-        if out is not None:
-            out.copy_(ret)
-            return out
-        return ret
-
-    torch.library.register_kernel(
-        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
-    )
+# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set.
+# This CPU implementation is not enough for our use case, so we still have to
+# keep our own custom version.
+def _scaled_mm_cpu_out(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
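+    # Upcast both operands to out_dtype and apply their per-tensor scales before the matmul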
+    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
 
-    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
-    def _scaled_mm_cpu(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-    ) -> Tensor:
-        return _scaled_mm_cpu_out(
-            mat1,
-            mat2,
-            scale1,
-            scale2,
-            bias,
-            scale_result,
-            out_dtype,
-            use_fast_accum,
-            out=None,
-        )
+    if bias is not None:
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
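+# Make the custom kernel the CPU implementation of the ATen out-variant of _scaled_mm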
+torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
+
+
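+# Functional (non-out) variant of aten::_scaled_mm on CPU; delegates to the out-variant above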
+@torch.library.register_kernel("aten::_scaled_mm", "cpu")
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )
 
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
@@ -218,6 +218,7 @@ def scaled_paged_attn_compute(
     num_kv_heads = value_cache.shape[2]
     head_size = value_cache.shape[3]
     block_size = value_cache.shape[1]
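+    # per-sequence query length, possibly including left padding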
+    seq_len_q = query.shape[1]
     num_seqs = query.shape[0]
 
     block_tables_lst = block_table.cpu().tolist()
@@ -228,6 +229,7 @@ def scaled_paged_attn_compute(
         block_table = block_tables_lst[i]
         start_pos = int(left_padded_prompt_mask[i].item())
         seq_len = int(seq_lens_lst[i])
+        seq_len_q_i = seq_len_q
 
         keys_lst: list[torch.Tensor] = []
         values_lst: list[torch.Tensor] = []
@@ -243,13 +245,25 @@ def scaled_paged_attn_compute(
             values_lst.append(v)
         keys = torch.stack(keys_lst, dim=0)
         values = torch.stack(values_lst, dim=0)
+        seq_len_kv = keys.shape[0]
+
+        # cut the left padding from the query for the first prefill so it matches the cached length
+        if q.shape[0] > seq_len_kv:
+            seq_len_q_i = seq_len_kv
+            q = q[-seq_len_kv:]
+
         if num_kv_heads > 1:
             # Handle MQA and GQA
             keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1)
             values = torch.repeat_interleave(
                 values, num_query_heads // num_kv_heads, dim=1
             )
 
+        # Generate mask for prefix attention
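+        # cached/context positions stay fully visible; the last seq_len_q_i key positions are causal w.r.t. the query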
+        mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
+        mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
+        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+
         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
             (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
@@ -258,12 +272,12 @@ def scaled_paged_attn_compute(
             (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
-            is_causal=False,  # decode assumes no causal mask
+            attn_mask=mask,  # causal over the new query tokens, full attention to the cached prefix
             scale=scale,
         )
 
-        out = out.view(num_query_heads, head_size)
-        output[i].copy_(out, non_blocking=True)
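+        # write the result into the right-aligned (non-padded) slice of this sequence's output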
+        out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size)
+        output[i][-seq_len_q_i:] = out
     return output
 
 