Commit fae7b28

Merge pull request #163 from ani300/mark-dynamic-fp8
fix: Mark FP8 scale to have the same batch size as input
2 parents: 1a86b4c + c1a68d7

File tree

1 file changed: +20 −1 lines


fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 20 additions & 1 deletion
```diff
@@ -318,12 +318,31 @@ def _spyre_scaled_paged_compute_op(
         attn_kwargs["block_table"],
     )
 
+def __spyre_scaled_paged_validate_attn_kwargs_op(
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+    **attn_kwargs,
+):
+    __spyre_paged_validate_attn_kwargs_op(
+        input_ids, position_ids, past_key_value_states, **attn_kwargs
+    )
+
+    if past_key_value_states is not None:
+        for k, v in past_key_value_states:
+            assert isinstance(k, ScaledTensor)
+            assert isinstance(v, ScaledTensor)
+
+            # assert that for each layer, the scales are per-sequence
+            assert k._scale.shape[0] == input_ids.shape[0]
+            assert v._scale.shape[0] == input_ids.shape[0]
+
 register_attention_op(
     "spyre_paged_attn_fp8",
     _spyre_scaled_paged_store_op,
     compute_op=_math_fp8_compute_op,
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,
-    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+    validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op,
 )
```
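For context on what the new validator enforces: each layer's FP8 KV-cache entry is a `ScaledTensor`, and its `_scale` must carry one value per sequence, so its leading dimension has to match the batch dimension of `input_ids`. Below is a minimal, self-contained sketch of that invariant; the stripped-down `ScaledTensor`, the helper name, and the tensor shapes are illustrative assumptions, not fms_mo's actual definitions.

```python
import torch


class ScaledTensor:
    """Hypothetical stand-in for fms_mo's ScaledTensor, reduced to the
    one attribute the validator inspects: a per-sequence `_scale`."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data    # quantized payload (FP8 in the real class)
        self._scale = scale  # dequantization scale, one entry per sequence


def validate_per_sequence_scales(input_ids, past_key_value_states):
    # Mirrors the assertions in __spyre_scaled_paged_validate_attn_kwargs_op:
    # for each layer's (k, v) cache entry, the scale's leading dimension
    # must equal the batch size of input_ids.
    for k, v in past_key_value_states:
        assert isinstance(k, ScaledTensor) and isinstance(v, ScaledTensor)
        assert k._scale.shape[0] == input_ids.shape[0]
        assert v._scale.shape[0] == input_ids.shape[0]


batch, seq, heads, head_dim = 2, 16, 4, 64  # illustrative shapes
input_ids = torch.zeros(batch, seq, dtype=torch.long)

good = ScaledTensor(torch.zeros(batch, heads, seq, head_dim), torch.ones(batch))
validate_per_sequence_scales(input_ids, [(good, good)])  # passes

bad = ScaledTensor(torch.zeros(batch, heads, seq, head_dim), torch.ones(1))
try:
    validate_per_sequence_scales(input_ids, [(bad, bad)])
except AssertionError:
    print("rejected: scale is per-tensor, not per-sequence")
```

Given the branch name (mark-dynamic-fp8) and the commit title, the point of the check is presumably that the scale's batch dimension should track the input's batch dimension rather than being a fixed or per-tensor size; the assertions make any mismatch fail loudly at kwargs-validation time instead of surfacing as a shape error inside the attention op.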
