
Commit 6c54b37

Merge pull request #154 from ani300/fp8_attn_addon
feat: addons for FP8 attention bmm, paged attention, and linear in FMS
2 parents 1945e07 + 42528b0

File tree

11 files changed (+1181, -4 lines)

fms_mo/aiu_addons/__init__.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
def _infer_quantization_config(quant_config: dict) -> dict | None:
    """Construct the linear_config dictionary carrying FP8 configuration for FMS.

    There are many quantization packages compatible with HF; we initially focus
    on llm-compressor, as it is the one used in FMS-MO.

    llm-compressor saves its checkpoints with quant_method = compressed-tensors.
    quantization_status tells us whether the model has already been quantized;
    we only support loading already-quantized models (compressed status).
    """

    if (
        quant_config["quant_method"] == "compressed-tensors"
        and quant_config["quantization_status"] == "compressed"
    ):
        # FP8 quantization will have FP8 weights
        # We assume a single quantization group (group_0), to follow fms-mo checkpoints
        # num_bits and type tell us "float" with "8" bits, aka FP8
        if (
            quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
            and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
        ):
            # First, import required FP8 linear classes from fms-mo
            # Local
            import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
            import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import

            # This is used by get_linear to decide whether a linear layer
            # will be quantized or not inside the model
            def fp8_linear_type(name: str) -> str:
                # We need to translate HF names to FMS names
                translations = {
                    "lm_head": "head",
                }
                for ignored_layer in quant_config["ignore"]:
                    assert isinstance(ignored_layer, str)
                    fms_ign_layer = translations.get(ignored_layer, ignored_layer)
                    if name in fms_ign_layer:
                        return "torch_linear"
                for pattern in quant_config["config_groups"]["group_0"]["targets"]:
                    # Special case from llm-compressor that covers all linear layers
                    # not in the ignore pattern
                    assert isinstance(pattern, str)
                    if pattern == "Linear":
                        return "fp8"
                    if name in translations.get(pattern, pattern):
                        return "fp8"
                return "torch_linear"

            return {
                "linear_type": fp8_linear_type,
                "input_activations": quant_config["config_groups"]["group_0"][
                    "input_activations"
                ],
                "output_activations": quant_config["config_groups"]["group_0"][
                    "output_activations"
                ],
                "weights": quant_config["config_groups"]["group_0"]["weights"],
            }
    return None
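
Below is a minimal usage sketch, assuming fms-mo and FMS are installed so the fp8 addon imports above succeed. The dictionary imitates the quantization_config block that llm-compressor writes into an HF config.json; its field values are illustrative, not taken from a real checkpoint.

# Local
from fms_mo.aiu_addons import _infer_quantization_config

# Hypothetical llm-compressor-style quantization_config (illustrative values)
example_quant_config = {
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"type": "float", "num_bits": 8},
            "input_activations": {"type": "float", "num_bits": 8},
            "output_activations": None,
        }
    },
}

linear_config = _infer_quantization_config(example_quant_config)
if linear_config is not None:
    # "lm_head" is translated to FMS's "head" and left unquantized;
    # every other linear layer matches the "Linear" target and maps to FP8.
    assert linear_config["linear_type"]("head") == "torch_linear"
    assert linear_config["linear_type"]("attn.query") == "fp8"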

fms_mo/aiu_addons/fp8/__init__.py

Whitespace-only changes.
fms_mo/aiu_addons/fp8/fp8_adapter.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement and register FMS adapters for FP8 checkpoint loading."""

# Standard
from typing import Any, Mapping
import functools

# Local
from fms_mo.prep import available_packages

if available_packages["fms"]:
    # Third Party
    from fms.modules.linear import get_linear_type
    from fms.utils import serialization
    from fms.utils.config import ModelConfig

    # pylint: disable=unused-argument
    # Retaining kwargs input arguments for consistency with other adapter steps.
    # TODO: may be shared with gptq llama
    def _hf_fp8_check(
        input_sd: Mapping[str, Any],
        model_config: ModelConfig | None = None,
        checkpoint_is_fused: bool = False,
        **kwargs,
    ) -> Mapping[str, Any]:
        """Implementation of adapter step for FMS: ensure that when FP8 quantization
        is in use, weights are fused like the model checkpoint.
        """

        has_fused_weights = True
        linear_type = "torch_linear"
        if model_config:
            if not model_config.fused_weights:
                has_fused_weights = False
            if model_config.linear_config:
                linear_type = model_config.linear_config["linear_type"]
                if callable(linear_type):
                    # Calling this function with "any" guarantees "fp8" to be returned
                    # when loading an HF fp8 checkpoint, and never in any other condition
                    linear_type = get_linear_type(model_config.linear_config, "any")

        if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused:
            raise ValueError(
                "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"
            )

        return input_sd

    serialization.register_adapter_step(
        "llama",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("llama", "hf", ["hf_fp8_check"])

    serialization.register_adapter_step(
        "granite",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("granite", "hf", ["hf_fp8_check"])
