Skip to content

Commit 6c15512

Browse files
committed
Mark scale dimensions to have the same batch size as input
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent 1a86b4c commit 6c15512

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,23 @@ def _spyre_scaled_paged_compute_op(
317317
attn_kwargs["left_padded_prompt_mask"],
318318
attn_kwargs["block_table"],
319319
)
320+
321+
def __spyre_scaled_paged_validate_attn_kwargs_op(
322+
input_ids: torch.Tensor,
323+
position_ids: torch.Tensor,
324+
past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
325+
**attn_kwargs,
326+
):
327+
__spyre_paged_validate_attn_kwargs_op(input_ids, position_ids, past_key_value_states, **attn_kwargs)
328+
329+
if past_key_value_states is not None:
330+
for k, v in past_key_value_states:
331+
assert isinstance(k, ScaledTensor)
332+
assert isinstance(v, ScaledTensor)
333+
334+
# assert that for each layer, the scales are per-sequence
335+
assert k._scale.shape[0] == input_ids.shape[0]
336+
assert v._scale.shape[0] == input_ids.shape[0]
320337

321338
register_attention_op(
322339
"spyre_paged_attn_fp8",
@@ -325,5 +342,5 @@ def _spyre_scaled_paged_compute_op(
325342
is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
326343
is None,
327344
compute_decode_op=_spyre_scaled_paged_compute_op,
328-
validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
345+
validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op,
329346
)

0 commit comments

Comments
 (0)