Commit 845a7e9

feat: update ibm_fms engine to support variant config overrides (IBM#58)
#### Motivation

The ibm_fms engine is too strict about the "variant" selection. The FMS get_model call requires a variant, but it also supports overriding configuration parameters via kwargs.

#### Modifications

Adds support for additional FMS Calico models whose configuration may differ from the default variant config, such as the size of the vocab.

Signed-off-by: Travis Johnson <[email protected]>
1 parent 771d023 commit 845a7e9
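The change relies on FMS's get_model accepting configuration overrides as keyword arguments alongside the required variant, as shown in the diff below. A minimal sketch of that call shape follows; the variant name, checkpoint path, and vocab size are hypothetical, and the Calico architecture is assumed to be registered by importing it from fms_extras.

```python
# Sketch only: mirrors the get_model call shape used in the diff below.
# The variant name, checkpoint path, and src_vocab_size are hypothetical.
import fms_extras.models.calico  # assumed to register the "calico" architecture with FMS
from fms.models import get_model

model = get_model(
    "calico",                      # architecture
    "8b",                          # hypothetical variant name resolved from config.json
    "/models/calico-checkpoint",   # hypothetical Megatron checkpoint path
    source="megatron",
    device_type="cpu",
    # keyword arguments override the variant's default CalicoConfig values,
    # e.g. a checkpoint whose vocabulary is larger than the variant default
    src_vocab_size=49280,
)
```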

File tree

1 file changed: +121 −34 lines

  • server/text_generation_server/inference_engine

Lines changed: 121 additions & 34 deletions

@@ -1,15 +1,19 @@
 import torch
 from typing import Any, Optional, Dict
+from loguru import logger

 from fms.models import get_model, list_variants, __models as _fms_models
 from fms.models.hf import to_hf_api
-from fms_extras.models.calico import CalicoConfig
+from fms.utils.activation import __ACT_2_CLS as FMS_ACT_2_CLASS

 from text_generation_server.inference_engine.engine import BaseInferenceEngine

+# Register FMS model classes with HF
 from fms_extras.models.hf import register_fms_models
+
 register_fms_models()

+
 class InferenceEngine(BaseInferenceEngine):
     def __init__(
         self,
@@ -22,60 +26,143 @@ def __init__(
     ) -> None:
         # model_config comes in a as a Dict to support the late registration
         if (model_type := model_config["model_type"]) != "gpt_megatron":
-            raise ValueError(f"Unknown model type {model_type} passed to ibm_fms engine.")
+            raise ValueError(
+                f"Unknown model type {model_type} passed to ibm_fms engine."
+            )

         super().__init__(model_path, model_config)

-        variant_match = self._find_variant(model_config, 'calico')
+        # only calico supported currently
+        fms_architecture_name = "calico"
+
+        # model_config can override what is set by the variant
+        fms_config_dict = self._convert_model_config(model_config)
+
+        variant_match = self._find_variant(fms_config_dict, fms_architecture_name)
         if variant_match is None:
-            raise ValueError(f"Unable to determine model variant from model: {model_path}")
+            raise ValueError(
+                f"Unable to determine model variant for model: {model_path}"
+            )

         # get_model does not have a dtype parameter, setting the default dtype
         # reduces memory required to load the model compared to converting afterwards
         orig_dtype = torch.get_default_dtype()
         try:
             torch.set_default_dtype(dtype)
-            calico_model = get_model("calico", variant_match, model_path, source="megatron", device_type=self.device.type)
-
-            # the CalicoConfig does not persist token ids to the
-            # HFAdaptedConfig, so pass them along explicitly
-            self.model = to_hf_api(
-                calico_model,
-                pad_token_id=model_config['pad_token_id'],
-                bos_token_id=model_config['bos_token_id'],
-                eos_token_id=model_config['eos_token_id'],
-            ).requires_grad_(False).eval()
+            fms_model = get_model(
+                fms_architecture_name,
+                variant_match,
+                model_path,
+                source="megatron",
+                device_type=self.device.type,
+                **fms_config_dict,
+            )
+            # get_model does not persist token ids to the HFAdaptedConfig, so
+            # pass them along explicitly
+            self.model = (
+                to_hf_api(
+                    fms_model,
+                    pad_token_id=model_config["pad_token_id"],
+                    bos_token_id=model_config["bos_token_id"],
+                    eos_token_id=model_config["eos_token_id"],
+                )
+                .requires_grad_(False)
+                .eval()
+            )
         finally:
             torch.set_default_dtype(orig_dtype)

         # update the config to config object instead of a dict
         self._config = self.model.config

     @classmethod
-    def _find_variant(cls, model_config_dict: Dict, fms_architecture) -> Optional[str]:
+    def _convert_model_config(cls, model_config_dict: Dict) -> Dict:
+        # mapping between CalicoConfig attributes and keys in the model_config dict
+        act_fn_attr = "activation_fn"
+        fms_config_attr_to_config_json_key = {
+            "src_vocab_size": "vocab_size",
+            "emb_dim": "n_embd",
+            "norm_eps": "layer_norm_epsilon",
+            "nheads": "n_head",
+            "kvheads": "num_key_value_heads",
+            "nlayers": "n_layer",
+            "pad_id": "pad_token_id",
+            # No entry in config.json for hidden_growth_factor
+            # No entry in config.json for multiple_of
+            act_fn_attr: "activation_function",
+            # ignore p_dropout for inference
+            "max_expected_seq_len": "n_positions",
+        }
+        fms_config_dict = {
+            attr: model_config_dict[key]
+            for attr, key in fms_config_attr_to_config_json_key.items()
+            if key in model_config_dict
+        }
+
+        # the activation function name may need to be converted
+        if act_fn := fms_config_dict.get(act_fn_attr):
+            fms_config_dict[act_fn_attr] = cls._convert_activation_function_name(act_fn)
+
+        return fms_config_dict
+
+    @classmethod
+    def _convert_activation_function_name(cls, act_name: str) -> str:
+        """Attempts to find an FMS compatible activation function name
+
+        gpt_megatron models may use different names for the activation function
+        compared to FMS, specifically around whether "GLU" is indicated
+        explicitly.
+
+        Refer to the fms.utils.activation module to see supported names
+        """
+        glu_activation_function_mapping = {
+            "geglu": "gelu",
+            "miglu": "mish",
+            "mishglu": "mish",
+            "reglu": "relu",
+            "swiglu": "swish",
+        }
+        if act_name.endswith("_glu"):
+            fms_act_name = act_name.rstrip("_glu")
+        elif new_name := glu_activation_function_mapping.get(act_name):
+            fms_act_name = new_name
+        else:
+            fms_act_name = act_name
+
+        # ensure the final act name is supported by FMS
+        if fms_act_name not in FMS_ACT_2_CLASS:
+            raise ValueError(f"Unsupported activation function: {act_name}.")
+
+        return fms_act_name
+
+    @classmethod
+    def _find_variant(
+        cls, fms_config_dict: Dict, fms_architecture: str
+    ) -> Optional[str]:
         # get a list of variant configs to compare against the model_config_dict
         variant_map = {
             # HACK: extract the variant config from the closure created for the factory functions...
             v: _fms_models[fms_architecture][v].__closure__[0].cell_contents
             for v in list_variants(fms_architecture)
         }
-        for v, v_config in variant_map.items():
-            if cls._is_variant_compatible(model_config_dict, v_config):
-                return v
-        return None

-    @classmethod
-    def _is_variant_compatible(cls, model_config_dict: Dict, config: CalicoConfig) -> bool:
-        dict_key_to_attr = {
-            'vocab_size': 'src_vocab_size',
-            'n_embd': 'emb_dim',
-            'n_head': 'nheads',
-            'num_key_value_heads': 'kvheads',
-            'n_layer': 'nlayers',
-            'n_positions': 'max_expected_seq_len',
-            'pad_token_id': 'pad_id',
-        }
-        for key, attr in dict_key_to_attr.items():
-            if model_config_dict[key] != getattr(config, attr, None):
-                return False
-        return True
+        # attributes of the CalicoConfig that must exist and match to find a
+        # compatible "variant"
+        variant_attrs_to_check = [
+            "emb_dim",
+            "nheads",
+            "kvheads",
+            "nlayers",
+        ]
+        if not all(fms_config_dict.get(attr, None) for attr in variant_attrs_to_check):
+            raise ValueError(
+                f"Unable to find compatible variant, the following configurations must exist {variant_attrs_to_check}"
+            )
+
+        for v_name, v_config in variant_map.items():
+            if all(
+                fms_config_dict.get(attr) == getattr(v_config, attr)
+                for attr in variant_attrs_to_check
+            ):
+                return v_name
+        return None
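For reference, a rough standalone sketch of what the new _convert_model_config mapping produces for a hypothetical gpt_megatron config.json is shown below. The sample values are made up, and the real method additionally validates the converted activation name against FMS's activation registry.

```python
# Standalone sketch mirroring the key mapping added in _convert_model_config above.
# All config.json values here are hypothetical.
config_json = {
    "model_type": "gpt_megatron",
    "vocab_size": 49280,           # e.g. a Calico checkpoint with a larger vocab
    "n_embd": 2048,
    "n_head": 16,
    "num_key_value_heads": 1,
    "n_layer": 24,
    "n_positions": 2048,
    "pad_token_id": 0,
    "activation_function": "swiglu",
}

fms_attr_to_json_key = {
    "src_vocab_size": "vocab_size",
    "emb_dim": "n_embd",
    "norm_eps": "layer_norm_epsilon",
    "nheads": "n_head",
    "kvheads": "num_key_value_heads",
    "nlayers": "n_layer",
    "pad_id": "pad_token_id",
    "activation_fn": "activation_function",
    "max_expected_seq_len": "n_positions",
}

# keep only the keys present in config.json, renamed to CalicoConfig attributes
fms_kwargs = {
    attr: config_json[key]
    for attr, key in fms_attr_to_json_key.items()
    if key in config_json
}
# _convert_activation_function_name maps "swiglu" to "swish"
fms_kwargs["activation_fn"] = "swish"

print(fms_kwargs)
# {'src_vocab_size': 49280, 'emb_dim': 2048, 'nheads': 16, 'kvheads': 1,
#  'nlayers': 24, 'pad_id': 0, 'activation_fn': 'swish', 'max_expected_seq_len': 2048}
```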
