Commit 6f289b0

Browse files
committed
Update FP8 bmm
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent b2c0d54 commit 6f289b0

File tree: 4 files changed, +186 -86 lines changed

Lines changed: 32 additions & 71 deletions
@@ -14,7 +14,6 @@
 """FMS registration of attention BMM operation using torch-registered scaled BMM."""
 
 # Standard
-from importlib.util import find_spec
 from typing import NotRequired, Unpack
 import math

@@ -24,79 +23,44 @@
     _sdpa_update_attn_kwargs,
     register_attention_op,
 )
-from torch import Tensor
 import torch
 
 # Local
-import fms_mo.aiu_addons.fp8.fp8_aiu_op  # pylint: disable=unused-import
-
-if find_spec("torchao"):
-    TORCHAO_INSTALLED = True
-    # Third Party
-    from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
-    from torchao.dtypes.floatx.float8_layout import (
-        Float8AQTTensorImpl,
-        Float8Layout,
-        Float8MMConfig,
-    )
-    from torchao.quantization.granularity import PerTensor
-    from torchao.quantization.observer import get_block_size
-    from torchao.quantization.quant_primitives import ZeroPointDomain
-else:
-    TORCHAO_INSTALLED = False
+from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
+import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import
 
 
 class MathFP8AttentionKwargs(AttentionKwargs):
     """TypedDict for FP8 attention."""
 
-    mask: NotRequired[Tensor]
+    mask: NotRequired[torch.Tensor]
     do_scale_q: bool
     is_causal_mask: bool
 
 
-# TODO: Doesn't quite work yet, more discussion needed
+# TODO: Figure out better scales for AIU? These come from vLLM
 Q_RANGE = 200.0
 K_RANGE = 200.0
 V_RANGE = 100.0
 
 
-def _construct_fp8_cache(
-    tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype
-) -> AffineQuantizedTensor:
-    """Construct the torchao tensor to save kv cache with its scales."""
-
-    weight_granularity = PerTensor()
-    fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True))
-    return AffineQuantizedTensor(
-        Float8AQTTensorImpl.from_plain(
-            tensor,
-            scale,
-            None,
-            fp8_layout,
-        ),
-        get_block_size(tensor.shape, weight_granularity),
-        tensor.shape,
-        zero_point_domain=ZeroPointDomain.NONE,
-        dtype=orig_dtype,
-    )
+def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor:
+    """Construct the custom object to save KV cache with its scales."""
+    return ScaledTensor(tensor, scale)
 
 
 def _math_fp8_store_op(
-    keys: Tensor,  # pylint: disable=unused-argument
-    values: Tensor,
-    key_cache: Tensor | None,
-    value_cache: Tensor | None,
+    keys: torch.Tensor,  # pylint: disable=unused-argument
+    values: torch.Tensor,
+    key_cache: torch.Tensor | None,
+    value_cache: torch.Tensor | None,
     **attn_kwargs: Unpack[MathFP8AttentionKwargs],
-) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]:
     """Implement math of KV cache storing."""
 
-    orig_dtype = keys.dtype
-
-    if isinstance(key_cache, AffineQuantizedTensor) and isinstance(
-        value_cache, AffineQuantizedTensor
-    ):
-        k_scale = key_cache.tensor_impl.scale
-        v_scale = value_cache.tensor_impl.scale
+    if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
+        k_scale = key_cache._scale
+        v_scale = value_cache._scale
     else:
         k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)
@@ -105,36 +69,35 @@ def _math_fp8_store_op(
     values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)
 
     if (
-        isinstance(key_cache, AffineQuantizedTensor)
-        and isinstance(value_cache, AffineQuantizedTensor)
+        isinstance(key_cache, ScaledTensor)
+        and isinstance(value_cache, ScaledTensor)
         and value_cache.numel() > 0
     ):
-        key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2)
-        value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2)
-        key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype)
-        value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype)
+        key_cache = torch.cat((key_cache._data, keys), dim=2)
+        value_cache = torch.cat((value_cache._data, values), dim=2)
+        key_cache = _construct_fp8_cache(key_cache, k_scale)
+        value_cache = _construct_fp8_cache(value_cache, v_scale)
         return (
             key_cache,
             value_cache,
             key_cache,
             value_cache,
         )
-
-    keys = _construct_fp8_cache(keys, k_scale, orig_dtype)
-    values = _construct_fp8_cache(values, v_scale, orig_dtype)
+    keys = _construct_fp8_cache(keys.contiguous(), k_scale)
+    values = _construct_fp8_cache(values.contiguous(), v_scale)
     return (keys, values, keys, values)
 
 
 def _math_fp8_compute_op(
-    query: Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
     nheads: int,
     kvheads: int,
     p_dropout: float,
     scale_factor: float | None,
     **attn_kwargs: Unpack[MathFP8AttentionKwargs],
-) -> Tensor:
+) -> torch.Tensor:
     """Implement computation of attention BMM, leveraging the custom scaled attention
     BMM op that was pre-registered for torch.compile."""

@@ -147,13 +110,11 @@ def _math_fp8_compute_op(
 
     query = query.to(torch.float8_e4m3fn).transpose(2, 1)
 
-    if isinstance(key_cache, AffineQuantizedTensor) and isinstance(
-        value_cache, AffineQuantizedTensor
-    ):
-        k_scale = key_cache.tensor_impl.scale
-        v_scale = value_cache.tensor_impl.scale
-        key_cache = key_cache.tensor_impl.float8_data
-        value_cache = value_cache.tensor_impl.float8_data
+    if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
+        k_scale = key_cache._scale
+        v_scale = value_cache._scale
+        key_cache = key_cache._data
+        value_cache = value_cache._data
     else:
         k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)
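
For context, the diff above replaces the torchao AffineQuantizedTensor cache wrapper with the new ScaledTensor container. The sketch below mirrors the first-store branch of _math_fp8_store_op (no pre-existing cache); the tensor shapes, the bfloat16 input dtype, and running it standalone are illustrative assumptions, not part of the commit.

import torch

from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor

K_RANGE, V_RANGE = 200.0, 100.0  # fixed ranges from the module above

# assumed layout: [batch, seq_len, kv_heads, head_dim]
keys = torch.randn(1, 64, 8, 128, dtype=torch.bfloat16)
values = torch.randn(1, 64, 8, 128, dtype=torch.bfloat16)

# dynamic per-tensor scales, kept in fp32
k_scale = (keys.abs().max() / K_RANGE).to(torch.float32)
v_scale = (values.abs().max() / V_RANGE).to(torch.float32)

# quantize to FP8 (e4m3) and swap the sequence and kv-head dims, as the store op does
keys_fp8 = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1).contiguous()
values_fp8 = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1).contiguous()

# each cache entry now carries the FP8 payload together with its scale
key_cache = ScaledTensor(keys_fp8, k_scale)
value_cache = ScaledTensor(values_fp8, v_scale)

# later steps concatenate new FP8 keys/values along dim=2 (the sequence dimension
# after the transpose) and re-wrap with the cached scales, e.g.:
#   key_cache = ScaledTensor(torch.cat((key_cache._data, new_keys_fp8), dim=2), k_scale)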

fms_mo/aiu_addons/fp8/fp8_linear.py

Whitespace-only changes.
Lines changed: 6 additions & 15 deletions
@@ -21,7 +21,7 @@
 # abstract op must be registered with specific I/O, even if not in use by the op function
 
 
-@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=())
+@torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def sendnn_scaled_bmm(
     mat1: Tensor,
     mat2: Tensor,
@@ -38,17 +38,8 @@ def sendnn_scaled_bmm(
     assert (
         mat1.shape[:-2] == mat2.shape[:-2]
     ), "batch dimensions must match for mat1 and mat2"
-    assert (
-        mat1.shape[:-2] == scale1.shape[:-2]
-    ), "batch dimensions must match for mat1 and scale1"
-    assert (
-        mat2.shape[:-2] == scale2.shape[:-2]
-    ), "batch dimensions must match for mat2 and scale2"
-
     mat1 = mat1.view(-1, *mat1.shape[-2:])
     mat2 = mat2.view(-1, *mat2.shape[-2:])
-    scale1 = scale1.view(-1, *scale1.shape[-2:])
-    scale2 = scale2.view(-1, *scale2.shape[-2:])
     out = torch.empty(
         (mat1.shape[0], mat1.shape[1], mat2.shape[2]),
         dtype=out_dtype,
@@ -58,12 +49,12 @@ def sendnn_scaled_bmm(
         out[b_idx] = torch._scaled_mm(
             mat1[b_idx],
             mat2[b_idx],
-            scale1[b_idx],
-            scale2[b_idx],
-            out_dtype,
-            use_fast_accum,
+            scale1,
+            scale2,
+            out_dtype=out_dtype,
+            use_fast_accum=use_fast_accum,
         )
-    return out
+    return out.view(*mat1.shape[:-2], mat1.shape[1], mat2.shape[2])
 
 
 @sendnn_scaled_bmm.register_fake
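
After this change the custom op takes one per-tensor scale per operand (scalar tensors) rather than per-batch scale tensors, and is registered under spyre::scaled_bmm. Below is a hypothetical eager call: the parameter names beyond mat1/mat2 (scale1, scale2, out_dtype, use_fast_accum) are inferred from the op body above, and running it eagerly needs a backend where torch._scaled_mm supports FP8 (for example CUDA SM 8.9+); on AIU/Spyre the op is intended to be captured by torch.compile instead.

import torch

# importing the module registers the op with torch.library
import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: F401

b, m, k, n = 2, 32, 64, 48  # dims kept as multiples of 16 for torch._scaled_mm

mat1 = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
# second operand is built column-major, matching torch._scaled_mm's layout rule
mat2 = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)

scale1 = torch.tensor(1.0, device="cuda")  # per-tensor fp32 scales
scale2 = torch.tensor(1.0, device="cuda")

out = torch.ops.spyre.scaled_bmm(
    mat1, mat2, scale1=scale1, scale2=scale2,
    out_dtype=torch.bfloat16, use_fast_accum=True,
)
print(out.shape)  # (b, m, n)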

fms_mo/aiu_addons/fp8/fp8_utils.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""FMS registration of attention BMM operation using torch-registered scaled BMM."""
+
+# Standard
+import functools
+
+# Third Party
+import torch
+
+# pylint: disable=unused-argument
+# unusued arguments are needed for templates
+
+
+_HANDLED_FUNCTIONS = {}
+
+
+def _implements(torch_function):
+    """Register a torch function override"""
+
+    def decorator(func):
+        @functools.wraps(torch_function)
+        def wrapper(f, types, args, kwargs):
+            return func(f, types, args, kwargs)
+
+        _HANDLED_FUNCTIONS[torch_function] = wrapper
+        return func
+
+    return decorator
+
+
+class ScaledTensor(torch.Tensor):
+    """Representation of a quantized tensor and its scale."""
+
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            data.size(),
+            strides=data.stride(),
+            storage_offset=data.storage_offset(),
+            dtype=data.dtype,
+            layout=data.layout,
+            requires_grad=data.requires_grad,
+            device=data.device,
+        )
+
+    def __init__(
+        self,
+        data: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        self._data = data
+        self._scale = scale
+
+    def __tensor_flatten__(self):
+        ctx = {}
+        return ["_data", "_scale"], ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
+        assert len(inner_tensors) == 2
+        return ScaledTensor(
+            inner_tensors["_data"],
+            inner_tensors["_scale"],
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func in _HANDLED_FUNCTIONS:
+            return _HANDLED_FUNCTIONS[func](func, types, args, kwargs)
+
+        arg_types = tuple(type(arg) for arg in args)
+        kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
+        raise NotImplementedError(
+            f"{cls.__name__} dispatch: attempting to run unimplemented "
+            f"operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}"
+        )
+
+    def __repr__(self):
+        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
+
+
+def _infer_quantization_config(quant_config: dict) -> dict | None:
+    # There's many quantization packages compatible with HF
+    # We initially focus on llm-compressor as it is the one used in FMS-MO
+
+    # llm-compressor saves its checkpoints with quant_method = compressed-tensors
+    # quantization_status tells us whether the model has already been quantized
+    # We only support loading already quantized models (compressed status)
+    if (
+        quant_config["quant_method"] == "compressed-tensors"
+        and quant_config["quantization_status"] == "compressed"
+    ):
+        # FP8 quantization will have FP8 weights
+        # We assume a single quantization group (group_0), to follow fms-mo checkpoints
+        # num_bits and type tells us "float" with "8" bits, aka FP8
+        if (
+            quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
+            and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
+        ):
+            # This is used by get_linear to decide whether a linear layer
+            # will be quantized or not inside the model
+            def fp8_linear_type(name: str) -> str:
+                # We need to translate HF names to FMS names
+                translations = {
+                    "lm_head": "head",
+                }
+                for ignored_layer in quant_config["ignore"]:
+                    assert isinstance(ignored_layer, str)
+                    fms_ign_layer = translations.get(ignored_layer, ignored_layer)
+                    if name in fms_ign_layer:
+                        return "torch_linear"
+                for pattern in quant_config["config_groups"]["group_0"]["targets"]:
+                    # Special case from llm-compressor that covers all linear layers
+                    # not in the ignore pattern
+                    assert isinstance(pattern, str)
+                    if pattern == "Linear":
+                        return "fp8"
+                    if name in translations.get(pattern, pattern):
+                        return "fp8"
+                return "torch_linear"
+
+            return {
+                "linear_type": fp8_linear_type,
+                "input_activations": quant_config["config_groups"]["group_0"][
+                    "input_activations"
+                ],
+                "output_activations": quant_config["config_groups"]["group_0"][
+                    "output_activations"
+                ],
+                "weights": quant_config["config_groups"]["group_0"]["weights"],
+            }
+    return None
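
ScaledTensor is a wrapper subclass whose __torch_dispatch__ only runs ops that were explicitly registered through _implements. This commit registers no handlers; the detach override below is purely an illustrative assumption showing how one could be hooked in.

import torch

from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor, _implements


@_implements(torch.ops.aten.detach.default)
def _detach_handler(func, types, args, kwargs):
    # unwrap, detach the FP8 payload, and re-wrap it with the same scale
    tensor = args[0]
    return ScaledTensor(tensor._data.detach(), tensor._scale)


st = ScaledTensor(torch.randn(4, 4).to(torch.float8_e4m3fn), torch.tensor(0.05))
print(st.detach()._scale)  # routed through the handler registered above
# any op without a registered handler raises NotImplementedError from __torch_dispatch__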
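
_infer_quantization_config only recognizes llm-compressor style checkpoints. The sample below is a hypothetical quantization_config shaped after the fields the function reads (quant_method, quantization_status, config_groups.group_0, ignore); the concrete values are illustrative, not taken from a real model card.

from fms_mo.aiu_addons.fp8.fp8_utils import _infer_quantization_config

quant_config = {
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"type": "float", "num_bits": 8},
            "input_activations": {"type": "float", "num_bits": 8},
            "output_activations": None,
        }
    },
}

fms_config = _infer_quantization_config(quant_config)
linear_type = fms_config["linear_type"]
print(linear_type("head"))   # "torch_linear": lm_head is ignored (HF name mapped to "head")
print(linear_type("query"))  # "fp8": the "Linear" target covers all remaining linear layers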

0 commit comments