Skip to content

Commit bdf1cf2

Browse files
committed
Gate FMS imports
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent 6a88117 commit bdf1cf2

File tree

3 files changed

+299
-294
lines changed

3 files changed

+299
-294
lines changed

fms_mo/aiu_addons/fp8/fp8_adapter.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
from typing import Any, Mapping
import functools

# Local
from fms_mo.prep import available_packages

# Gate everything on FMS availability: this module must stay importable
# (as a no-op) in environments where the optional `fms` package is absent.
if available_packages["fms"]:
    # Third Party
    from fms.modules.linear import get_linear_type
    from fms.utils import serialization
    from fms.utils.config import ModelConfig

    # pylint: disable=unused-argument
    # Retaining kwargs input arguments for consistency with other adapter steps.
    # TODO: may be shared with gptq llama
    def _hf_fp8_check(
        input_sd: Mapping[str, Any],
        model_config: ModelConfig | None = None,
        checkpoint_is_fused: bool = False,
        **kwargs,
    ) -> Mapping[str, Any]:
        """Implementation of adapter step for FMS: ensure that when FP8 quantization
        is in use, weights are fused like the model checkpoint.

        Args:
            input_sd: Checkpoint state dict being adapted; returned unchanged.
            model_config: Target model configuration. When None, the defaults
                (fused weights, plain torch linear) are assumed and the check
                passes trivially.
            checkpoint_is_fused: Whether the checkpoint on disk stores fused
                weights (HF FP8 checkpoints are unfused, so the registrations
                below bind this to False).
            **kwargs: Unused; kept for signature consistency with other
                adapter steps.

        Returns:
            The unmodified `input_sd`.

        Raises:
            ValueError: If an FP8 linear type is selected but the model's
                weight fusion does not match the checkpoint's.
        """

        has_fused_weights = True
        linear_type = "torch_linear"
        if model_config:
            if not model_config.fused_weights:
                has_fused_weights = False
            if model_config.linear_config:
                linear_type = model_config.linear_config["linear_type"]
                if callable(linear_type):
                    # Calling this function with "any" guarantees "fp8" to be returned
                    # when loading an HF fp8 checkpoint, and never in any other condition
                    linear_type = get_linear_type(model_config.linear_config, "any")

        if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused:
            raise ValueError(
                "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"
            )

        return input_sd

    # Register the FP8 check as an extra "hf" adapter step for both
    # architectures that support HF FP8 checkpoints.
    serialization.register_adapter_step(
        "llama",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("llama", "hf", ["hf_fp8_check"])

    serialization.register_adapter_step(
        "granite",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("granite", "hf", ["hf_fp8_check"])

0 commit comments

Comments
 (0)