1 parent f9ca98a commit 9103f54
fms_mo/aiu_addons/fp8/fp8_attn.py
@@ -21,6 +21,7 @@
 import torch

 # Local
+from fms.modules.attention import _sdpa_compute_op
 from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
 from fms_mo.prep import available_packages
 import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import
@@ -340,7 +341,7 @@ def __spyre_scaled_paged_validate_attn_kwargs_op(
 register_attention_op(
     "spyre_paged_attn_fp8",
     _spyre_scaled_paged_store_op,
-    compute_op=_math_fp8_compute_op,
+    compute_op=_sdpa_compute_op,
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,
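The registration above pairs a prefill compute op with a decode compute op and uses is_prefill_op to choose between them; the absence of a paged-KV block_table is treated as prefill. A minimal sketch of that selection logic, assuming a hypothetical dispatch_compute helper (for illustration only, not the fms API):

# Minimal sketch of the prefill/decode selection implied by the registration
# above. dispatch_compute is hypothetical; the op names mirror those registered.
def dispatch_compute(attn_kwargs, compute_op, compute_decode_op, is_prefill_op):
    if is_prefill_op(**attn_kwargs):
        # Prefill: no block_table supplied, so use the prefill compute op
        # (after this change, _sdpa_compute_op imported from fms).
        return compute_op
    # Decode: a block_table is present, so use the Spyre scaled paged compute op.
    return compute_decode_op

# Example selection mirroring the registered kwargs:
# op = dispatch_compute(
#     attn_kwargs,
#     compute_op=_sdpa_compute_op,
#     compute_decode_op=_spyre_scaled_paged_compute_op,
#     is_prefill_op=lambda **kw: kw.get("block_table", None) is None,
# )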