@@ -318,12 +318,31 @@ def _spyre_scaled_paged_compute_op(
         attn_kwargs["block_table"],
     )
 
+def __spyre_scaled_paged_validate_attn_kwargs_op(
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+    **attn_kwargs,
+):
+    __spyre_paged_validate_attn_kwargs_op(
+        input_ids, position_ids, past_key_value_states, **attn_kwargs
+    )
+
+    if past_key_value_states is not None:
+        for k, v in past_key_value_states:
+            assert isinstance(k, ScaledTensor)
+            assert isinstance(v, ScaledTensor)
+
+            # assert that for each layer, the scales are per-sequence
+            assert k._scale.shape[0] == input_ids.shape[0]
+            assert v._scale.shape[0] == input_ids.shape[0]
+
 register_attention_op(
     "spyre_paged_attn_fp8",
     _spyre_scaled_paged_store_op,
     compute_op=_math_fp8_compute_op,
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,
-    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+    validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op,
 )
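For context, a minimal self-contained sketch of the invariant the new validator enforces: each layer's fp8 KV-cache scale tensor must carry one scale per sequence in the batch (scale.shape[0] == batch size). The helper name check_per_sequence_scales and the plain-tensor stand-ins below are hypothetical; the real validator additionally requires the cache entries to be ScaledTensor instances and first delegates to the base paged validator.

# Hypothetical sketch, not the FMS API: plain tensors stand in for the scales.
import torch

def check_per_sequence_scales(
    input_ids: torch.Tensor,
    kv_scales: list[tuple[torch.Tensor, torch.Tensor]],
) -> None:
    batch_size = input_ids.shape[0]
    for k_scale, v_scale in kv_scales:
        # one scale value per sequence in the batch, for both K and V caches
        assert k_scale.shape[0] == batch_size
        assert v_scale.shape[0] == batch_size

# usage: a batch of 2 sequences with 2 layers of per-sequence scales passes the check
ids = torch.zeros(2, 8, dtype=torch.long)
scales = [(torch.ones(2), torch.ones(2)) for _ in range(2)]
check_per_sequence_scales(ids, scales)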