1414"""FMS registration of attention BMM operation using torch-registered scaled BMM."""
1515
1616# Standard
17- from typing import NotRequired , Unpack
17+ from typing import NotRequired , Optional , Unpack
1818import math
1919
2020# Third Party
2323 _sdpa_update_attn_kwargs ,
2424 register_attention_op ,
2525)
26+ from fms .utils .spyre .paged import (
27+ SpyrePagedAttentionKwargs ,
28+ __spyre_paged_validate_attn_kwargs_op ,
29+ )
2630import torch
2731
2832# Local
@@ -46,7 +50,7 @@ class MathFP8AttentionKwargs(AttentionKwargs):

 def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor:
     """Construct the custom object to save KV cache with its scales."""
-    return ScaledTensor(tensor, scale)
+    return ScaledTensor(tensor, scale, True)


 def _math_fp8_store_op(
@@ -58,13 +62,15 @@ def _math_fp8_store_op(
 ) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]:
     """Implement math of KV cache storing."""

+    # Grab scale from kv-cache if already there, compute dynamically otherwise
     if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
         k_scale = key_cache._scale
         v_scale = value_cache._scale
     else:
         k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)

+    # Scale kv tensors for storage
     keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1)
     values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)

@@ -83,6 +89,7 @@ def _math_fp8_store_op(
             key_cache,
             value_cache,
         )
+    # If it's a new kv cache, ensure it's contiguous for spyre use cases
     keys = _construct_fp8_cache(keys.contiguous(), k_scale)
     values = _construct_fp8_cache(values.contiguous(), v_scale)
     return (keys, values, keys, values)
@@ -98,35 +105,40 @@ def _math_fp8_compute_op(
     scale_factor: float | None,
     **attn_kwargs: Unpack[MathFP8AttentionKwargs],
 ) -> torch.Tensor:
-    """Implement computation of attention BMM, leveraging the custom scaled attention
-    BMM op that was pre-registered for torch.compile."""
+    """Implement computation of scaled dot product attention, leveraging
+    the custom scaled BMM op that was pre-registered for torch.compile."""

     orig_dtype = query.dtype

+    # Scaling the Q tensor is optional
     q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
     if attn_kwargs.get("do_scale_q", False):
         q_scale.copy_(torch.abs(query).max() / Q_RANGE)
         query = query / q_scale

     query = query.to(torch.float8_e4m3fn).transpose(2, 1)

+    # Grab kv cache and deal with cases where no store op was run
     if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
+        # Store op was run
         k_scale = key_cache._scale
         v_scale = value_cache._scale
         key_cache = key_cache._data
         value_cache = value_cache._data
     else:
+        # Store op wasn't run (e.g. encoders, use_cache=False)
         k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)
         key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn)
         value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn)

-    # no longer transposing prior to store, so need to check this in case of no cache
+    # If store wasn't run, we need to transpose the tensors here
     # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here
     if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads:
         key_cache = key_cache.transpose(2, 1)
         value_cache = value_cache.transpose(2, 1)

+    # Most of the code that follows is a copy of PyTorch SDPA, with fp8 additions
     mask = attn_kwargs.get("mask", None)
     if mask is not None:
         # Our expected mask format is bs x q_len x k_len, so to make it broadcastable
@@ -187,3 +199,86 @@ def _math_fp8_compute_op(
     _math_fp8_compute_op,
     update_attn_kwargs_op=_sdpa_update_attn_kwargs,
 )
+
+
+def _spyre_scaled_paged_store_op(
+    keys: torch.Tensor,
+    values: torch.Tensor,
+    key_cache: Optional[torch.Tensor],
+    value_cache: Optional[torch.Tensor],
+    **attn_kwargs: Unpack[SpyrePagedAttentionKwargs],
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    # For paged store, we must have pre-allocated the kv-cache
+    assert key_cache is not None and isinstance(
+        key_cache, ScaledTensor
+    ), "kv cache must be preallocated"
+    assert value_cache is not None and isinstance(
+        value_cache, ScaledTensor
+    ), "kv cache must be preallocated"
+    if not key_cache._scaled:
+        key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32)
+        value_cache._scale = (torch.abs(values).max() / 100.0).to(dtype=torch.float32)
+
+    result_key_cache_data, result_value_cache_data = (
+        torch.ops.spyre.scaled_paged_attn_store(
+            keys,
+            values,
+            key_cache._data,
+            value_cache._data,
+            key_cache._scale,
+            value_cache._scale,
+            attn_kwargs["slot_mapping"],
+        )
+    )
+
+    result_key_cache = _construct_fp8_cache(result_key_cache_data, key_cache._scale)
+    result_value_cache = _construct_fp8_cache(
+        result_value_cache_data, value_cache._scale
+    )
+
+    # for prefill, we want to return the original keys/values
+    if attn_kwargs.get("block_table", None) is None:
+        return keys, values, result_key_cache, result_value_cache
+    return (
+        result_key_cache,
+        result_value_cache,
+        result_key_cache,
+        result_value_cache,
+    )
+
+
+def _spyre_scaled_paged_compute_op(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    nheads: int,
+    kvheads: int,
+    p_dropout: float,
+    scale_factor: Optional[float],
+    **attn_kwargs,
+) -> torch.Tensor:
+    assert isinstance(key_cache, ScaledTensor), "kv cache must be scaled"
+    assert isinstance(value_cache, ScaledTensor), "kv cache must be scaled"
+    if scale_factor is None:
+        scale_factor = 1 / math.sqrt(query.shape[-1])
+    return torch.ops.spyre.scaled_paged_attn_compute(
+        query,
+        key_cache._data,
+        value_cache._data,
+        key_cache._scale,
+        value_cache._scale,
+        scale_factor,
+        attn_kwargs["current_tkv_mask"],
+        attn_kwargs["left_padded_prompt_mask"],
+        attn_kwargs["block_table"],
+    )
+
+
+register_attention_op(
+    "spyre_paged_attn_fp8",
+    _spyre_scaled_paged_store_op,
+    compute_op=_math_fp8_compute_op,
+    is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None,
+    compute_decode_op=_spyre_scaled_paged_compute_op,
+    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+)
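
A note on the scaling math used in the diff: _math_fp8_store_op and _math_fp8_compute_op both apply dynamic per-tensor abs-max scaling before narrowing to torch.float8_e4m3fn, and multiply the saved scale back in when computing. The following is a minimal standalone sketch of that quantize/dequantize pattern, not FMS code; RANGE, quantize_fp8, and dequantize_fp8 are hypothetical names standing in for the module's K_RANGE/V_RANGE constants and inline math, and it assumes a PyTorch build with float8 dtypes (2.1+).

import torch

RANGE = 240.0  # hypothetical stand-in for K_RANGE / V_RANGE

def quantize_fp8(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # per-tensor scale derived from the absolute maximum, kept in fp32
    scale = (torch.abs(t).max() / RANGE).to(dtype=torch.float32)
    # divide by the scale, then narrow to fp8 for storage
    return (t / scale).to(torch.float8_e4m3fn), scale

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # widen back and multiply by the saved scale to approximate the original values
    return q.to(dtype) * scale

x = torch.randn(2, 8, 64)
q, s = quantize_fp8(x)
x_approx = dequantize_fp8(q, s)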
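
The ScaledTensor wrapper itself is only touched through three attributes here: _data (the fp8 payload), _scale (the per-tensor scale), and _scaled (whether a scale has been populated, which is what the third positional argument passed in _construct_fp8_cache appears to set). The stand-in below is an assumption based solely on those accesses, for orientation only, and is not the real FMS class.

from dataclasses import dataclass
import torch

@dataclass
class ScaledTensorSketch:
    # hypothetical stand-in for the FMS ScaledTensor, limited to the fields this diff touches
    _data: torch.Tensor    # fp8 tensor payload
    _scale: torch.Tensor   # per-tensor fp32 scale factor
    _scaled: bool = False  # True once a store op has populated _scale

Note also that the paged ops treat a missing block_table kwarg as the prefill case (see the is_prefill_op lambda), which is why _spyre_scaled_paged_store_op returns the unquantized keys and values during prefill and the fp8 caches otherwise.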