Commit 496bf44: Add FP8 linear to FMS addon

Signed-off-by: Andrea Fasoli <[email protected]>
Parent: 30f76a9

File tree: 3 files changed (+339, -6 lines)

fms_mo/aiu_addons/fp8/fp8_adapter.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@
 from fms.utils import serialization
 from fms.utils.config import ModelConfig

+# pylint: disable=unused-argument
+# Retaining kwargs input arguments for consistency.
+
+
 # NOTE: this adapter step must be registered before the adapter that uses it (such as
 # the llama adapter in fms.models.llama)
 # TODO: may be shared with gptq llama

Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement FP8 linear module to be loaded via FMS."""

# Standard
from importlib.util import find_spec
from typing import Any, Mapping

# Third Party
from fms.modules.linear import (
    LinearModuleShardingInfo,
    LinearParameterShardingInfo,
    register_linear_type_to_module_map,
    register_linear_type_to_sharding_map,
    shard_base_linear,
)
from fms.modules.tp import ShardType, TPModule
import torch

# pylint: disable=not-callable
# torch.nn.functional.linear not recognized as callable
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


### FP8 linear layers
if find_spec("torchao"):
    TORCHAO_INSTALLED = True

    # Third Party
    from torchao.dtypes.affine_quantized_tensor import (  # type: ignore
        AffineQuantizedTensor,
        to_affine_quantized_floatx,
        to_affine_quantized_floatx_static,
    )
    from torchao.dtypes.floatx.float8_layout import (  # type: ignore
        Float8AQTTensorImpl,
        Float8Layout,
        Float8MMConfig,
        preprocess_data,
        preprocess_scale,
    )
    from torchao.dtypes.utils import get_out_shape  # type: ignore
    from torchao.float8.inference import (  # type: ignore
        _is_rowwise_scaled,
        addmm_float8_unwrapped_inference,
    )
    from torchao.quantization.granularity import PerRow, PerTensor  # type: ignore
    from torchao.quantization.observer import get_block_size  # type: ignore
    from torchao.quantization.quant_primitives import ZeroPointDomain  # type: ignore
else:
    TORCHAO_INSTALLED = False


class FP8Linear(torch.nn.Module):
    """Class handles FP8 weights loading and uses torchao for the matmuls."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        linear_config: Mapping[str, Any],
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = bias
        self.linear_config = linear_config

        assert (
            self.linear_config["weights"] is not None
        ), "Weights must always be quantized for FP8Linear"
        assert self.linear_config["weights"][
            "symmetric"
        ], "We only support symmetric weights for now"
        assert not self.linear_config["weights"][
            "dynamic"
        ], "We only support pre-quantized weights for now"

        self.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn),
            requires_grad=False,
        )

        weight_scale_shape = (
            (1,)
            if self.linear_config["weights"]["strategy"] == "tensor"
            else (out_features, 1)
        )
        self.weight_scale = torch.nn.Parameter(
            torch.ones(weight_scale_shape), requires_grad=False
        )

        self.has_bias = bias
        if self.has_bias:
            self.bias = torch.nn.Parameter(torch.zeros((out_features,)))

        if (
            self.linear_config["input_activations"] is not None
            and not self.linear_config["input_activations"]["dynamic"]
        ):
            input_scale_shape = (
                (1,)
                if self.linear_config["input_activations"]["strategy"] == "tensor"
                else (out_features, 1)
            )
            self.input_scale = torch.nn.Parameter(
                torch.ones(input_scale_shape), requires_grad=False
            )

    def _input_activation_quant_func_fp8(
        self,
        x: torch.Tensor,
        activation_granularity,
        activation_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ):
        """Quantize the input activation tensor for an aqt_float variant.
        If scale is not provided, it will be dynamically calculated, otherwise the
        provided scale will be used.
        """

        block_size = get_block_size(x.shape, activation_granularity)
        if scale is None:
            activation = to_affine_quantized_floatx(
                input_float=x,
                block_size=block_size,
                target_dtype=activation_dtype,
                scale_dtype=torch.float32,
                _layout=Float8Layout(mm_config=None),  # Config is stored on weight
            )
        else:
            assert isinstance(
                activation_granularity, PerTensor
            ), "Static quantization only supports PerTensor granularity"
            activation = to_affine_quantized_floatx_static(
                input_float=x,
                block_size=block_size,
                scale=scale,
                target_dtype=activation_dtype,
                _layout=Float8Layout(mm_config=None),  # Config is stored on weight
            )
        return activation

    def _construct_qweight_structure(self) -> "AffineQuantizedTensor":
        """Construct the torchao machinery for the fp8 matmul"""

        weight_granularity = (
            PerTensor()
            if self.linear_config["weights"]["strategy"] == "tensor"
            else PerRow()
        )
        fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True))
        return AffineQuantizedTensor(
            Float8AQTTensorImpl.from_plain(
                self.weight,
                self.weight_scale.squeeze().to(torch.float32),
                None,
                fp8_layout,
            ),
            get_block_size(self.weight.shape, weight_granularity),
            self.weight.shape,
            zero_point_domain=ZeroPointDomain.NONE,
            dtype=self.weight_scale.dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """If input quantization is active, compute FP8xFP8 addmm."""

        # fp8 weight tensor for torchao
        qweight: AffineQuantizedTensor = self._construct_qweight_structure()

        if self.linear_config["input_activations"] is not None:
            # activations are also fp8, quantize as required by model
            act_granularity = (
                PerTensor()
                if self.linear_config["input_activations"]["strategy"] == "tensor"
                else PerRow()
            )
            input_quant_kwargs = {
                "activation_granularity": act_granularity,
                "activation_dtype": torch.float8_e4m3fn,
            }
            if not self.linear_config["input_activations"]["dynamic"]:
                input_quant_kwargs["scale"] = self.input_scale.squeeze().to(
                    torch.float32
                )
            qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs)

            # Copied from torchao _linear_fp8_act_fp8_weight_impl (with changes to support fp8 out)
            scaled_mm_config = Float8MMConfig(use_fast_accum=True)
            out_shape = get_out_shape(qx.shape, qweight.shape)

            # Weight tensor preprocessing
            w_tensor_impl = qweight.tensor_impl
            assert not w_tensor_impl.transposed, "Weight tensor must be contiguous"
            w_data = w_tensor_impl.float8_data
            w_scale = w_tensor_impl.scale

            # Input tensor preprocessing
            inpt_data = qx.tensor_impl.float8_data
            input_scale = qx.tensor_impl.scale
            # Handle case where input tensor is more than 2D
            inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])

            # Handle rowwise case
            if _is_rowwise_scaled(qweight):
                assert _is_rowwise_scaled(qx), "Input tensor must be rowwise block size"
                w_scale = w_scale.unsqueeze(-1).T
                input_scale = preprocess_scale(input_scale, qx.shape)

            # Preprocess data
            inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

            # Perform the computation
            return addmm_float8_unwrapped_inference(
                inpt_data,
                input_scale,
                w_data,
                w_scale,
                output_dtype=qx.dtype,
                bias=getattr(self, "bias", None),
                use_fast_accum=scaled_mm_config.use_fast_accum,
            ).reshape(out_shape)

        # activations not quantized, dequant fp8 weight and do regular matmul
        out = torch.nn.functional.linear(
            x, qweight.dequantize(), self.bias if self.has_bias else None
        )
        return out

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(in={self.in_features}, out={self.out_features}, "
            f"bias={self.has_bias}, fp8_config={self.linear_config})"
        )


def get_fp8_linear(
    in_features: int,
    out_features: int,
    bias: bool,
    linear_config: Mapping[str, Any],
) -> FP8Linear:
    """Retrieve an FP8 Linear module"""

    if not TORCHAO_INSTALLED:
        raise ModuleNotFoundError("You need to install torchao for FP8 support in FMS!")

    return FP8Linear(in_features, out_features, bias, linear_config)


def shard_fp8_linear(
    tensor_values: dict[str, torch.Tensor],
    tp_module: TPModule,
    module_sharding_info: dict[str, LinearModuleShardingInfo],
) -> set | None:
    """
                              |     GPU     |
    sharding  | param         | shard | dim |
    ----------+---------------+-------+-----|
    colwise   | weight        |   Y   |  0  |
              | weight_scale  |   N   |  -  |
              | input_scale   |   N   |  -  |
              | bias          |   Y   |  0  |
    ----------+---------------+-------+-----|
    rowwise   | weight        |   Y   |  1  |
              | weight_scale  |  Y/N  | 0/- |
              | input_scale   |  Y/N  | 0/- |
              | bias          |   0   |  -  |
    """
    param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
    for module_name, module_info in module_sharding_info.items():
        linear_mod: torch.nn.Module = module_info.linear_module
        weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][
            "strategy"
        ]
        # Scales are per-row or per-tensor
        # Only sharding needed when row parallel and per-row
        shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1
        params: dict[str, LinearParameterShardingInfo] = {
            "weight": LinearParameterShardingInfo(
                module_info.sharding_dim, ShardType.SHARD
            ),
            "weight_scale": LinearParameterShardingInfo(
                module_info.sharding_dim,
                ShardType.SHARD if shard_scales else ShardType.CLONE,
            ),
        }
        if hasattr(linear_mod, "input_scale"):
            params["input_scale"] = LinearParameterShardingInfo(
                module_info.sharding_dim,
                ShardType.SHARD if shard_scales else ShardType.CLONE,
            )
        if hasattr(linear_mod, "bias") and linear_mod.bias is not None:
            params["bias"] = LinearParameterShardingInfo(
                module_info.sharding_dim,
                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
            )
        param_sharding_info[module_name] = params

    unused_keys = shard_base_linear(
        tensor_values,
        tp_module,
        module_sharding_info,
        param_sharding_info,
    )
    return unused_keys


register_linear_type_to_module_map("fp8", get_fp8_linear)
register_linear_type_to_sharding_map("fp8", shard_fp8_linear)
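
To make the expected inputs of the new module concrete, here is a minimal usage sketch that is not part of the commit. The linear_config keys mirror what FP8Linear.__init__ reads above; the import path, layer sizes, and config values are illustrative assumptions.

# Minimal sketch: construct an FP8Linear directly (assumptions noted inline).
from typing import Any, Mapping

import torch

# Hypothetical import path: the new file's module name is not captured in this diff.
from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear

linear_config: Mapping[str, Any] = {
    # weights must be pre-quantized (dynamic=False) and symmetric, per the asserts in __init__
    "weights": {"symmetric": True, "dynamic": False, "strategy": "tensor"},
    # dynamic activation quantization: no input_scale parameter is allocated
    "input_activations": {"dynamic": True, "strategy": "tensor"},
}

fp8_linear = FP8Linear(
    in_features=4096,  # illustrative sizes
    out_features=4096,
    bias=False,
    linear_config=linear_config,
)

# Placeholders created by __init__, to be overwritten when a checkpoint is loaded:
#   fp8_linear.weight        -> torch.float8_e4m3fn, shape (4096, 4096)
#   fp8_linear.weight_scale  -> shape (1,), since strategy == "tensor"
print(fp8_linear)

# Note: forward() quantizes activations and calls the torchao scaled matmul,
# so running it requires torchao to be installed (see TORCHAO_INSTALLED above).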

fms_mo/aiu_addons/fp8/fp8_utils.py

Lines changed: 10 additions & 6 deletions
@@ -59,7 +59,7 @@ def __new__(
             device=data.device,
         )

-    def __init__(
+    def __init__(  # pylint: disable=super-init-not-called
         self,
         data: torch.Tensor,
         scale: torch.Tensor,
@@ -96,12 +96,16 @@ def __repr__(self):


 def _infer_quantization_config(quant_config: dict) -> dict | None:
-    # There's many quantization packages compatible with HF
-    # We initially focus on llm-compressor as it is the one used in FMS-MO
+    """Construct linear_config dictionary carrying FP8 configuration for FMS.
+
+    There's many quantization packages compatible with HF
+    We initially focus on llm-compressor as it is the one used in FMS-MO
+
+    llm-compressor saves its checkpoints with quant_method = compressed-tensors
+    quantization_status tells us whether the model has already been quantized
+    We only support loading already quantized models (compressed status)
+    """

-    # llm-compressor saves its checkpoints with quant_method = compressed-tensors
-    # quantization_status tells us whether the model has already been quantized
-    # We only support loading already quantized models (compressed status)
     if (
         quant_config["quant_method"] == "compressed-tensors"
         and quant_config["quantization_status"] == "compressed"
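
For context, the checkpoint metadata that _infer_quantization_config inspects looks roughly like the dictionary below. Only the quant_method and quantization_status keys are checked in the diff above; the comment about additional fields is an assumption about llm-compressor output, not taken from the commit.

# Sketch of a quant_config handled by _infer_quantization_config (illustrative only).
quant_config = {
    "quant_method": "compressed-tensors",  # written by llm-compressor checkpoints
    "quantization_status": "compressed",   # weights are already quantized
    # ... further llm-compressor fields (e.g. per-group weight/activation schemes)
}
# The helper only proceeds to build an FMS linear_config when both conditions hold.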
