1414"""FMS registration of attention BMM operation using torch-registered scaled BMM."""
1515
1616# Standard
17- from importlib.util import find_spec
1817from typing import NotRequired, Unpack
1918import math
2019
2423 _sdpa_update_attn_kwargs,
2524 register_attention_op,
2625)
27- from torch import Tensor
2826import torch
2927
3028# Local
31- import fms_mo.aiu_addons.fp8.fp8_aiu_op # pylint: disable=unused-import
32-
33- if find_spec("torchao"):
34- TORCHAO_INSTALLED = True
35- # Third Party
36- from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
37- from torchao.dtypes.floatx.float8_layout import (
38- Float8AQTTensorImpl,
39- Float8Layout,
40- Float8MMConfig,
41- )
42- from torchao.quantization.granularity import PerTensor
43- from torchao.quantization.observer import get_block_size
44- from torchao.quantization.quant_primitives import ZeroPointDomain
45- else:
46- TORCHAO_INSTALLED = False
29+ from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
30+ import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import
4731
4832
4933class MathFP8AttentionKwargs(AttentionKwargs):
5034 """TypedDict for FP8 attention."""
5135
52- mask: NotRequired[Tensor]
36+ mask: NotRequired[torch. Tensor]
5337 do_scale_q: bool
5438 is_causal_mask: bool
5539
5640
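The three fields above travel into the store/compute ops below via **attn_kwargs. A hypothetical call-site fragment (values are illustrative; fields inherited from AttentionKwargs are omitted, and the flag semantics are inferred from their names):

    attn_kwargs = {
        "do_scale_q": True,      # assumption: scale the query before the fp8 cast
        "is_causal_mask": True,  # assumption: apply causal masking
        # "mask" is NotRequired and may be left out
    }
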
-# TODO: Doesn't quite work yet, more discussion needed
+# TODO: Figure out better scales for AIU? These come from vLLM
 Q_RANGE = 200.0
 K_RANGE = 200.0
 V_RANGE = 100.0

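These RANGE constants drive the per-tensor scaling used in the ops below: a scale is computed as max(|x|) / RANGE, so x / scale never exceeds RANGE and stays well inside the float8_e4m3fn representable range (max magnitude 448). A minimal standalone sketch of that recipe, not part of this module:

    x = torch.randn(2, 16, 8, 64, dtype=torch.bfloat16)
    scale = (torch.abs(x).max() / K_RANGE).to(dtype=torch.float32)  # K_RANGE = 200.0
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)                     # max |x / scale| == 200 < 448
    x_approx = x_fp8.to(torch.bfloat16) * scale                     # rough dequantization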

-def _construct_fp8_cache(
-    tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype
-) -> AffineQuantizedTensor:
-    """Construct the torchao tensor to save kv cache with its scales."""
-
-    weight_granularity = PerTensor()
-    fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True))
-    return AffineQuantizedTensor(
-        Float8AQTTensorImpl.from_plain(
-            tensor,
-            scale,
-            None,
-            fp8_layout,
-        ),
-        get_block_size(tensor.shape, weight_granularity),
-        tensor.shape,
-        zero_point_domain=ZeroPointDomain.NONE,
-        dtype=orig_dtype,
-    )
+def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor:
+    """Construct the custom object to save KV cache with its scales."""
+    return ScaledTensor(tensor, scale)

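ScaledTensor itself lives in fms_mo.aiu_addons.fp8.fp8_utils and is not shown in this diff. A rough stand-in that mirrors only what this file relies on, namely the _data / _scale attributes (the real class is also expected to behave like a tensor, since numel() is called on it below):

    class _ScaledTensorSketch:
        """Illustrative stand-in for ScaledTensor, not the real implementation."""

        def __init__(self, data: torch.Tensor, scale: torch.Tensor):
            self._data = data    # fp8 (float8_e4m3fn) payload
            self._scale = scale  # float32 per-tensor scale

        def numel(self) -> int:
            return self._data.numel()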

 def _math_fp8_store_op(
-    keys: Tensor,  # pylint: disable=unused-argument
-    values: Tensor,
-    key_cache: Tensor | None,
-    value_cache: Tensor | None,
+    keys: torch.Tensor,  # pylint: disable=unused-argument
+    values: torch.Tensor,
+    key_cache: torch.Tensor | None,
+    value_cache: torch.Tensor | None,
     **attn_kwargs: Unpack[MathFP8AttentionKwargs],
-) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]:
     """Implement math of KV cache storing."""

-    orig_dtype = keys.dtype
-
-    if isinstance(key_cache, AffineQuantizedTensor) and isinstance(
-        value_cache, AffineQuantizedTensor
-    ):
-        k_scale = key_cache.tensor_impl.scale
-        v_scale = value_cache.tensor_impl.scale
+    if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
+        k_scale = key_cache._scale
+        v_scale = value_cache._scale
     else:
         k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)
@@ -105,36 +69,35 @@ def _math_fp8_store_op(
     values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)

     if (
-        isinstance(key_cache, AffineQuantizedTensor)
-        and isinstance(value_cache, AffineQuantizedTensor)
+        isinstance(key_cache, ScaledTensor)
+        and isinstance(value_cache, ScaledTensor)
         and value_cache.numel() > 0
     ):
-        key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2)
-        value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2)
-        key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype)
-        value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype)
+        key_cache = torch.cat((key_cache._data, keys), dim=2)
+        value_cache = torch.cat((value_cache._data, values), dim=2)
+        key_cache = _construct_fp8_cache(key_cache, k_scale)
+        value_cache = _construct_fp8_cache(value_cache, v_scale)
         return (
             key_cache,
             value_cache,
             key_cache,
             value_cache,
         )
-
-    keys = _construct_fp8_cache(keys, k_scale, orig_dtype)
-    values = _construct_fp8_cache(values, v_scale, orig_dtype)
+    keys = _construct_fp8_cache(keys.contiguous(), k_scale)
+    values = _construct_fp8_cache(values.contiguous(), v_scale)
     return (keys, values, keys, values)

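In effect, a first call with empty caches quantizes keys and values with freshly computed scales, while subsequent calls reuse the cached scales and append the new fp8 entries along the sequence dimension. A hypothetical two-step driver, assuming inputs laid out as (batch, seq, kv-heads, head-dim), which the transpose(2, 1) above suggests:

    k = torch.randn(1, 16, 4, 64, dtype=torch.bfloat16)
    v = torch.randn(1, 16, 4, 64, dtype=torch.bfloat16)
    _, _, k_cache, v_cache = _math_fp8_store_op(k, v, None, None)  # prefill: compute scales
    k_new = torch.randn(1, 1, 4, 64, dtype=torch.bfloat16)
    v_new = torch.randn(1, 1, 4, 64, dtype=torch.bfloat16)
    _, _, k_cache, v_cache = _math_fp8_store_op(
        k_new, v_new, k_cache, v_cache
    )  # decode: reuse scales, concatenate on dim=2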

 def _math_fp8_compute_op(
-    query: Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
     nheads: int,
     kvheads: int,
     p_dropout: float,
     scale_factor: float | None,
     **attn_kwargs: Unpack[MathFP8AttentionKwargs],
-) -> Tensor:
+) -> torch.Tensor:
     """Implement computation of attention BMM, leveraging the custom scaled attention
     BMM op that was pre-registered for torch.compile."""

@@ -147,13 +110,11 @@ def _math_fp8_compute_op(

     query = query.to(torch.float8_e4m3fn).transpose(2, 1)

-    if isinstance(key_cache, AffineQuantizedTensor) and isinstance(
-        value_cache, AffineQuantizedTensor
-    ):
-        k_scale = key_cache.tensor_impl.scale
-        v_scale = value_cache.tensor_impl.scale
-        key_cache = key_cache.tensor_impl.float8_data
-        value_cache = value_cache.tensor_impl.float8_data
+    if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
+        k_scale = key_cache._scale
+        v_scale = value_cache._scale
+        key_cache = key_cache._data
+        value_cache = value_cache._data
     else:
         k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)