
Commit b2c0d54

Addons for FP8 attention bmm in FMS
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent bcee5f3 commit b2c0d54

3 files changed: +312 -0 lines changed


fms_mo/aiu_addons/fp8/__init__.py

Whitespace-only changes.
Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FMS registration of attention BMM operation using torch-registered scaled BMM."""

# Standard
from importlib.util import find_spec
from typing import NotRequired, Unpack
import math

# Third Party
from fms.modules.attention import (
    AttentionKwargs,
    _sdpa_update_attn_kwargs,
    register_attention_op,
)
from torch import Tensor
import torch

# Local
import fms_mo.aiu_addons.fp8.fp8_aiu_op  # pylint: disable=unused-import

if find_spec("torchao"):
    TORCHAO_INSTALLED = True
    # Third Party
    from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
    from torchao.dtypes.floatx.float8_layout import (
        Float8AQTTensorImpl,
        Float8Layout,
        Float8MMConfig,
    )
    from torchao.quantization.granularity import PerTensor
    from torchao.quantization.observer import get_block_size
    from torchao.quantization.quant_primitives import ZeroPointDomain
else:
    TORCHAO_INSTALLED = False


class MathFP8AttentionKwargs(AttentionKwargs):
    """TypedDict for FP8 attention."""

    mask: NotRequired[Tensor]
    do_scale_q: bool
    is_causal_mask: bool


# TODO: Doesn't quite work yet, more discussion needed
Q_RANGE = 200.0
K_RANGE = 200.0
V_RANGE = 100.0


def _construct_fp8_cache(
    tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype
) -> AffineQuantizedTensor:
    """Construct the torchao tensor to save kv cache with its scales."""

    weight_granularity = PerTensor()
    fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True))
    return AffineQuantizedTensor(
        Float8AQTTensorImpl.from_plain(
            tensor,
            scale,
            None,
            fp8_layout,
        ),
        get_block_size(tensor.shape, weight_granularity),
        tensor.shape,
        zero_point_domain=ZeroPointDomain.NONE,
        dtype=orig_dtype,
    )


def _math_fp8_store_op(
    keys: Tensor,  # pylint: disable=unused-argument
    values: Tensor,
    key_cache: Tensor | None,
    value_cache: Tensor | None,
    **attn_kwargs: Unpack[MathFP8AttentionKwargs],
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implement math of KV cache storing."""

    orig_dtype = keys.dtype

    if isinstance(key_cache, AffineQuantizedTensor) and isinstance(
        value_cache, AffineQuantizedTensor
    ):
        k_scale = key_cache.tensor_impl.scale
        v_scale = value_cache.tensor_impl.scale
    else:
        k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
        v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)

    keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1)
    values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)

    if (
        isinstance(key_cache, AffineQuantizedTensor)
        and isinstance(value_cache, AffineQuantizedTensor)
        and value_cache.numel() > 0
    ):
        key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2)
        value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2)
        key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype)
        value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype)
        return (
            key_cache,
            value_cache,
            key_cache,
            value_cache,
        )

    keys = _construct_fp8_cache(keys, k_scale, orig_dtype)
    values = _construct_fp8_cache(values, v_scale, orig_dtype)
    return (keys, values, keys, values)

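# ------------------------------------------------------------------------------------
# Illustration only, not part of the committed file: a minimal sketch of the dynamic
# per-tensor scaling that _math_fp8_store_op performs when no FP8 cache exists yet.
# It assumes a PyTorch build with float8_e4m3fn support; the bs x seq_len x kvheads x
# head_dim layout and the K_RANGE/V_RANGE constants mirror the module above, and the
# tensor shapes are illustrative.

import torch

K_RANGE, V_RANGE = 200.0, 100.0
keys = torch.randn(2, 16, 8, 64, dtype=torch.bfloat16)  # bs x seq_len x kvheads x d
values = torch.randn(2, 16, 8, 64, dtype=torch.bfloat16)

# Dynamic per-tensor scales: the absolute max of each tensor is mapped to its RANGE
k_scale = (keys.abs().max() / K_RANGE).to(torch.float32)
v_scale = (values.abs().max() / V_RANGE).to(torch.float32)

# Quantize to FP8 and move heads ahead of the sequence dim, as the store op does
keys_fp8 = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1)
values_fp8 = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)

# Dequantizing with the saved scale recovers the originals up to FP8 rounding error
keys_deq = keys_fp8.to(torch.float32) * k_scale
print((keys_deq - keys.transpose(2, 1).to(torch.float32)).abs().max())
# ------------------------------------------------------------------------------------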

def _math_fp8_compute_op(
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    nheads: int,
    kvheads: int,
    p_dropout: float,
    scale_factor: float | None,
    **attn_kwargs: Unpack[MathFP8AttentionKwargs],
) -> Tensor:
    """Implement computation of attention BMM, leveraging the custom scaled attention
    BMM op that was pre-registered for torch.compile."""

    orig_dtype = query.dtype

    q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
    if attn_kwargs.get("do_scale_q", False):
        q_scale.copy_(torch.abs(query).max() / Q_RANGE)
        query = query / q_scale

    query = query.to(torch.float8_e4m3fn).transpose(2, 1)

    if isinstance(key_cache, AffineQuantizedTensor) and isinstance(
        value_cache, AffineQuantizedTensor
    ):
        k_scale = key_cache.tensor_impl.scale
        v_scale = value_cache.tensor_impl.scale
        key_cache = key_cache.tensor_impl.float8_data
        value_cache = value_cache.tensor_impl.float8_data
    else:
        k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
        v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)
        key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn)
        value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn)

    # no longer transposing prior to store, so need to check this in case of no cache
    # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here
    if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads:
        key_cache = key_cache.transpose(2, 1)
        value_cache = value_cache.transpose(2, 1)

    mask = attn_kwargs.get("mask", None)
    if mask is not None:
        # Our expected mask format is bs x q_len x k_len, so to make it broadcastable
        # we need to create the nheads dimension
        while len(mask.size()) != 4:  # expects bs (x nheads) x q_len x kv_len
            mask = mask.unsqueeze(1)

    L, S = query.size(-2), key_cache.size(-2)
    scale_factor = (
        1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor
    )
    attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device)
    if attn_kwargs.get("is_causal_mask", False):
        assert mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(torch.float32)

    if mask is not None:
        if mask.dtype == torch.bool:
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
        else:
            attn_bias = mask + attn_bias

    expansion = nheads // kvheads
    if expansion > 1:
        key_cache = key_cache.repeat_interleave(
            query.size(-3) // key_cache.size(-3), -3
        )
        value_cache = value_cache.repeat_interleave(
            query.size(-3) // value_cache.size(-3), -3
        )

    attn_weight = (
        torch.ops.sendnn.scaled_bmm(
            query,
            key_cache.transpose(-2, -1),
            q_scale,
            k_scale,
            out_dtype=orig_dtype,
            use_fast_accum=True,
        )
        * scale_factor
    )
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
    # Do matmul in orig_dtype
    attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale)

    attn = attn.to(orig_dtype).transpose(2, 1).contiguous()
    return attn


register_attention_op(
    "math_fp8",
    _math_fp8_store_op,
    _math_fp8_compute_op,
    update_attn_kwargs_op=_sdpa_update_attn_kwargs,
)
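For reference, the sketch below (not part of the committed files) spells out the full-precision math that the torch.ops.sendnn.scaled_bmm call above stands in for in the QK^T product: both FP8 operands are dequantized by their per-tensor scales and multiplied in higher precision. Q_RANGE and K_RANGE are the module constants; shapes and the float32 reference dtype are illustrative assumptions.

import torch

Q_RANGE, K_RANGE = 200.0, 200.0
query = torch.randn(2, 8, 16, 64)  # bs x nheads x q_len x head_dim
keys = torch.randn(2, 8, 32, 64)   # bs x nheads x kv_len x head_dim

# Dynamic per-tensor scales and FP8 quantization, as in the compute op
q_scale = query.abs().max() / Q_RANGE
k_scale = keys.abs().max() / K_RANGE
q_fp8 = (query / q_scale).to(torch.float8_e4m3fn)
k_fp8 = (keys / k_scale).to(torch.float8_e4m3fn)

# Reference for scaled_bmm(q_fp8, k_fp8^T, q_scale, k_scale, out_dtype=...):
# dequantize with the per-tensor scales and matmul in float32
attn_scores = (q_fp8.to(torch.float32) * q_scale) @ (
    k_fp8.to(torch.float32) * k_scale
).transpose(-2, -1)
print(attn_scores.shape)  # torch.Size([2, 8, 16, 32])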
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch registration of FP8xFP8 operation for attention BMMs."""

# Third Party
from torch import Tensor
import torch

# pylint: disable=unused-argument
# abstract op must be registered with specific I/O, even if not in use by the op function


@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=())
def sendnn_scaled_bmm(
    mat1: Tensor,
    mat2: Tensor,
    scale1: Tensor,
    scale2: Tensor,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> Tensor:
    """Implement a custom scaled attention BMM op: a batched version of _scaled_mm.
    The operations that are part of this function are not exposed to the computational
    graph, but are invoked when running on non-AIU devices.
    """

    assert (
        mat1.shape[:-2] == mat2.shape[:-2]
    ), "batch dimensions must match for mat1 and mat2"
    assert (
        mat1.shape[:-2] == scale1.shape[:-2]
    ), "batch dimensions must match for mat1 and scale1"
    assert (
        mat2.shape[:-2] == scale2.shape[:-2]
    ), "batch dimensions must match for mat2 and scale2"

    mat1 = mat1.view(-1, *mat1.shape[-2:])
    mat2 = mat2.view(-1, *mat2.shape[-2:])
    scale1 = scale1.view(-1, *scale1.shape[-2:])
    scale2 = scale2.view(-1, *scale2.shape[-2:])
    out = torch.empty(
        (mat1.shape[0], mat1.shape[1], mat2.shape[2]),
        dtype=out_dtype,
        device=mat1.device,
    )
    for b_idx in range(mat1.shape[0]):
        out[b_idx] = torch._scaled_mm(
            mat1[b_idx],
            mat2[b_idx],
            scale1[b_idx],
            scale2[b_idx],
            out_dtype=out_dtype,
            use_fast_accum=use_fast_accum,
        )
    return out


@sendnn_scaled_bmm.register_fake
def _(
    mat1: Tensor,
    mat2: Tensor,
    scale1: Tensor,
    scale2: Tensor,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> Tensor:
    """Template for scaled attention BMM operation. I/O retain the expected size."""

    return torch.empty(
        (*mat1.shape[:-2], mat1.shape[-2], mat2.shape[-1]),
        dtype=out_dtype,
        device=mat1.device,
    )
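As a quick sanity check, the hedged sketch below (not part of the committed files) exercises the fake registration: calling the op on meta tensors should dispatch to the fake kernel, so only the shape and dtype propagation runs, with no FP8-capable hardware needed. This is the path torch.compile relies on when tracing. The import path mirrors the one used by the attention module above; tensor shapes are illustrative.

import torch

import fms_mo.aiu_addons.fp8.fp8_aiu_op  # noqa: F401  # registers sendnn::scaled_bmm

# Meta tensors carry only shape/dtype metadata; no real FP8 matmul is executed
q = torch.empty(2, 8, 128, 64, dtype=torch.float8_e4m3fn, device="meta")
k_t = torch.empty(2, 8, 64, 256, dtype=torch.float8_e4m3fn, device="meta")
q_scale = torch.empty(2, 8, 1, 1, dtype=torch.float32, device="meta")
k_scale = torch.empty(2, 8, 1, 1, dtype=torch.float32, device="meta")

out = torch.ops.sendnn.scaled_bmm(q, k_t, q_scale, k_scale, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([2, 8, 128, 256]) torch.bfloat16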
