@@ -42,6 +42,7 @@ class MathFP8AttentionKwargs(AttentionKwargs):
 
     mask: NotRequired[torch.Tensor]
     do_scale_q: bool
+    do_scaled_bmm: bool
     is_causal_mask: bool
 
 # TODO: Figure out better scales for AIU? These come from vLLM
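For orientation, the new flag simply rides along in the attention kwargs. A hypothetical call-site sketch follows; the dict literal and the values are illustrative and not taken from this diff (the `NotRequired` annotation suggests `MathFP8AttentionKwargs` is a TypedDict, so callers would build a plain dict):

```python
# Illustrative only: how a caller might populate the kwargs with the new flag.
attn_kwargs = {
    "do_scale_q": True,     # optionally rescale Q before the FP8 cast
    "do_scaled_bmm": True,  # new: route QK^T through torch.ops.spyre.scaled_bmm
    "is_causal_mask": True,
}
```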
@@ -110,14 +111,17 @@ def _math_fp8_compute_op(
     the custom scaled BMM op that was pre-registered for torch.compile."""
 
     orig_dtype = query.dtype
+    do_scaled_bmm = attn_kwargs.get("do_scaled_bmm", False)
 
-    # Scaling the Q tensor is optional
-    q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
-    if attn_kwargs.get("do_scale_q", False):
-        q_scale.copy_(torch.abs(query).max() / Q_RANGE)
-        query = query / q_scale
+    if do_scaled_bmm:
+        # Scaling the Q tensor is optional
+        q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
+        if attn_kwargs.get("do_scale_q", False):
+            q_scale.copy_(torch.abs(query).max() / Q_RANGE)
+            query = query / q_scale
 
-    query = query.to(torch.float8_e4m3fn).transpose(2, 1)
+        query = query.to(torch.float8_e4m3fn)
+    query = query.transpose(2, 1)
 
     # Grab kv cache and deal with cases where no store op was run
     if isinstance(key_cache, ScaledTensor) and isinstance(
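To see what the now-conditional Q-scaling block does in isolation, here is a standalone sketch. It is not the repo's code: the shapes are made up and `Q_RANGE`'s actual value is not shown in this diff, so the float8_e4m3fn maximum (~448) is assumed purely for illustration.

```python
import torch

# Standalone illustration of the optional Q scaling (assumed Q_RANGE; the real
# constant comes from vLLM per the TODO near the top of the file).
Q_RANGE = 448.0

query = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16) * 30.0

q_scale = torch.tensor(1.0, dtype=torch.float32)
q_scale.copy_(query.abs().max() / Q_RANGE)             # dynamic per-tensor scale
query_fp8 = (query / q_scale).to(torch.float8_e4m3fn)  # cast only on the scaled-BMM path

# Dequantizing shows the error introduced by the FP8 cast.
roundtrip = query_fp8.to(torch.bfloat16) * q_scale
print((roundtrip - query).abs().max())
```

Note that only the FP8 cast stays inside the `do_scaled_bmm` branch; the `transpose(2, 1)` is hoisted out because the plain-matmul fallback below needs the same layout while keeping the query in `orig_dtype`.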
@@ -175,17 +179,22 @@ def _math_fp8_compute_op(
             query.size(-3) // value_cache.size(-3), -3
         )
 
-    attn_weight = (
-        torch.ops.spyre.scaled_bmm(
-            query,
-            key_cache.transpose(-2, -1),
-            q_scale,
-            k_scale,
-            out_dtype=orig_dtype,
-            use_fast_accum=True,
+    if do_scaled_bmm:
+        attn_weight = (
+            torch.ops.spyre.scaled_bmm(
+                query,
+                key_cache.transpose(-2, -1),
+                q_scale,
+                k_scale,
+                out_dtype=orig_dtype,
+                use_fast_accum=True,
+            )
+            * scale_factor
         )
-        * scale_factor
-    )
+    else:
+        key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1)
+        attn_weight = query @ key_t
+        attn_weight *= scale_factor
     attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
     attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
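For reference, a standalone sketch of what the new `else` branch computes, with made-up shapes and scale values; nothing here beyond the three quoted lines is taken from the repo, and `torch.ops.spyre.scaled_bmm` is a custom registered op that is not reproduced.

```python
import torch

# Standalone illustration of the fallback path: dequantize the FP8 K cache
# with its per-tensor scale, then do a plain matmul at the original dtype.
orig_dtype = torch.bfloat16
scale_factor = 1.0 / (64 ** 0.5)      # 1/sqrt(head_dim), illustrative

query = torch.randn(1, 8, 128, 64, dtype=orig_dtype)            # [b, h, q, d]
k_scale = torch.tensor(0.02, dtype=torch.float32)
key_cache = (torch.randn(1, 8, 256, 64) / k_scale).to(torch.float8_e4m3fn)

key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1)
attn_weight = query @ key_t                                      # [b, h, q, kv]
attn_weight *= scale_factor
print(attn_weight.shape)
```

The `do_scaled_bmm` branch keeps both operands in FP8 and hands `q_scale`/`k_scale` to the custom op, which presumably applies them inside the fused matmul; the fallback trades that for an explicit upcast of K plus a standard `@`, so both branches land in the same softmax/dropout code that follows.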