@@ -317,6 +317,23 @@ def _spyre_scaled_paged_compute_op(
         attn_kwargs["left_padded_prompt_mask"],
         attn_kwargs["block_table"],
     )
+
+def __spyre_scaled_paged_validate_attn_kwargs_op(
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+    **attn_kwargs,
+):
+    __spyre_paged_validate_attn_kwargs_op(input_ids, position_ids, past_key_value_states, **attn_kwargs)
+
+    if past_key_value_states is not None:
+        for k, v in past_key_value_states:
+            assert isinstance(k, ScaledTensor)
+            assert isinstance(v, ScaledTensor)
+
+            # assert that for each layer, the scales are per-sequence
+            assert k._scale.shape[0] == input_ids.shape[0]
+            assert v._scale.shape[0] == input_ids.shape[0]

 register_attention_op(
     "spyre_paged_attn_fp8",
@@ -325,5 +342,5 @@ def _spyre_scaled_paged_compute_op(
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,
-    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+    validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op,
 )
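
A minimal sketch of the contract the new validator enforces, not part of the change itself: each cached key/value must be a `ScaledTensor` whose scale has one entry per sequence in the batch. The `ScaledTensor` stand-in below is hypothetical and assumes only the `_scale` attribute the validator actually reads; the real class comes from the source tree.

```python
import torch


class ScaledTensor:
    """Hypothetical stand-in: pairs raw data with a per-sequence scale."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data
        self._scale = scale  # shape (batch,): one scale per sequence


batch, heads, seq, head_dim = 2, 4, 16, 64
input_ids = torch.randint(0, 100, (batch, seq))

# One (key, value) pair per layer; each carries a per-sequence scale.
past_key_value_states = [
    (
        ScaledTensor(torch.zeros(batch, heads, seq, head_dim), torch.ones(batch)),
        ScaledTensor(torch.zeros(batch, heads, seq, head_dim), torch.ones(batch)),
    )
]

# The same checks the new op performs after delegating to the unscaled validator:
for k, v in past_key_value_states:
    assert isinstance(k, ScaledTensor) and isinstance(v, ScaledTensor)
    assert k._scale.shape[0] == input_ids.shape[0]
    assert v._scale.shape[0] == input_ids.shape[0]
```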