Skip to content

Commit 9103f54

Browse files
committed
Change paged FP8 prefill back to regular sdpa
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent f9ca98a commit 9103f54

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
21 21
import torch
22 22

23 23
# Local
24 +
from fms.modules.attention import _sdpa_compute_op
24 25
from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
25 26
from fms_mo.prep import available_packages
26 27
import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import
@@ -340,7 +341,7 @@ def __spyre_scaled_paged_validate_attn_kwargs_op(
340 341
register_attention_op(
341 342
"spyre_paged_attn_fp8",
342 343
_spyre_scaled_paged_store_op,
343 -
compute_op=_math_fp8_compute_op,
344 +
compute_op=_sdpa_compute_op,
344 345
is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
345 346
is None,
346 347
compute_decode_op=_spyre_scaled_paged_compute_op,

0 commit comments

Comments
 (0)