Commit 30f76a9

Add FP8 adapter step

Signed-off-by: Andrea Fasoli <[email protected]>

Parent: 6f289b0

3 files changed: +57 −2 lines changed
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement and register FMS adapters for FP8 checkpoint loading."""

# Standard
from typing import Any, Mapping

# Third Party
from fms.modules.linear import get_linear_type
from fms.utils import serialization
from fms.utils.config import ModelConfig

# NOTE: this adapter step must be registered before the adapter that uses it (such as
# the llama adapter in fms.models.llama)
# TODO: may be shared with gptq llama
# TODO: generalize across architectures if possible
def _hf_fp8_llama_check(
    input_sd: Mapping[str, Any], model_config: ModelConfig | None = None, **kwargs
) -> Mapping[str, Any]:
    """Implementation of adapter step for FMS Llama: ensure that when FP8 quantization
    is in use, weights are unfused.
    """

    has_fused_weights = True
    linear_type = "torch_linear"
    if model_config:
        if not model_config.fused_weights:
            has_fused_weights = False
        if model_config.linear_config:
            linear_type = model_config.linear_config["linear_type"]
            if callable(linear_type):
                # Calling this with "any" guarantees "fp8" to be returned
                # when loading an HF fp8 checkpoint, and never in any other condition
                linear_type = get_linear_type(model_config.linear_config, "any")

    if "fp8" in linear_type and has_fused_weights:
        raise ValueError(
            "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"
        )

    return input_sd


serialization.register_adapter_step("llama", "hf_fp8_llama_check", _hf_fp8_llama_check)
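
For context, the sketch below shows how the check above behaves; it is illustrative only and not part of the commit. It calls _hf_fp8_llama_check directly with a hypothetical stand-in config object that exposes only the attributes the step reads (the real config is fms.utils.config.ModelConfig, and in practice the step runs inside the FMS adapter pipeline for the "llama" architecture). The new module's import path is not shown in this diff, so the import is omitted and assumed to be in scope.

# Hypothetical usage sketch (not part of this commit): exercise the adapter step
# directly. Assumes _hf_fp8_llama_check is already in scope; the new module's
# path is not shown in this diff, so no import is given here.
from dataclasses import dataclass


@dataclass
class _FakeConfig:
    # Minimal stand-in exposing only the attributes the check reads;
    # the real config is fms.utils.config.ModelConfig.
    fused_weights: bool = True
    linear_config: dict | None = None


# FP8 linear type with fused weights: the step rejects the checkpoint.
cfg = _FakeConfig(fused_weights=True, linear_config={"linear_type": "fp8"})
try:
    _hf_fp8_llama_check({}, model_config=cfg)
except ValueError as err:
    print(f"rejected as expected: {err}")

# FP8 linear type with unfused weights: the state dict passes through unchanged.
cfg = _FakeConfig(fused_weights=False, linear_config={"linear_type": "fp8"})
sd = {"some.weight": "tensor placeholder"}
assert _hf_fp8_llama_check(sd, model_config=cfg) is sd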

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def sendnn_scaled_bmm(
 ) -> Tensor:
     """Implement a custom scaled attention BMM op: a batched version of _scaled_mm.
     The operations that are part of this function are not exposed to the computational
-    graph, but are invoked when running on non-AIU devices.
+    graph, but are invoked when running on non-Spyre devices.
     """

     assert (
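
The body of sendnn_scaled_bmm is not part of this diff; purely as a hedged illustration of what "a batched version of _scaled_mm" can mean on devices without native FP8 matmul support, the sketch below dequantizes the FP8 operands, applies per-tensor scales, and falls back to torch.bmm. All names, shapes, and the scaling scheme here are assumptions, not the repository's implementation.

# Illustrative sketch only (not the body of sendnn_scaled_bmm): emulate a
# batched scaled matmul by dequantizing FP8 inputs and using torch.bmm.
# Requires PyTorch >= 2.1 for the float8 dtypes; names and shapes are assumptions.
import torch
from torch import Tensor


def naive_scaled_bmm(
    mat1: Tensor,  # (B, M, K), FP8 (e.g. torch.float8_e4m3fn)
    mat2: Tensor,  # (B, K, N), FP8
    scale1: Tensor,  # per-tensor scale applied to mat1
    scale2: Tensor,  # per-tensor scale applied to mat2
    out_dtype: torch.dtype = torch.bfloat16,
) -> Tensor:
    # Upcast to float32, apply the scales, then use a regular batched matmul.
    a = mat1.to(torch.float32) * scale1
    b = mat2.to(torch.float32) * scale2
    return torch.bmm(a, b).to(out_dtype)


# Example: two FP8 batched operands with scalar scales.
a = torch.randn(2, 4, 8).to(torch.float8_e4m3fn)
b = torch.randn(2, 8, 3).to(torch.float8_e4m3fn)
print(naive_scaled_bmm(a, b, torch.tensor(0.5), torch.tensor(2.0)).shape)
# -> torch.Size([2, 4, 3])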

fms_mo/aiu_addons/fp8/fp8_utils.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""FMS registration of attention BMM operation using torch-registered scaled BMM."""
+"""Utility functions and components for FP8 addon implementation."""

 # Standard
 import functools
