import torch
from typing import Any, Optional, Dict
+ from loguru import logger

from fms.models import get_model, list_variants, __models as _fms_models
from fms.models.hf import to_hf_api
- from fms_extras.models.calico import CalicoConfig
+ from fms.utils.activation import __ACT_2_CLS as FMS_ACT_2_CLASS

from text_generation_server.inference_engine.engine import BaseInferenceEngine

+ # Register FMS model classes with HF
from fms_extras.models.hf import register_fms_models
+
register_fms_models()

+

class InferenceEngine(BaseInferenceEngine):
    def __init__(
        self,
@@ -22,60 +26,143 @@ def __init__(
    ) -> None:
        # model_config comes in as a Dict to support the late registration
        if (model_type := model_config["model_type"]) != "gpt_megatron":
-             raise ValueError(f"Unknown model type {model_type} passed to ibm_fms engine.")
+             raise ValueError(
+                 f"Unknown model type {model_type} passed to ibm_fms engine."
+             )

        super().__init__(model_path, model_config)

-         variant_match = self._find_variant(model_config, 'calico')
+         # only calico supported currently
+         fms_architecture_name = "calico"
+
+         # model_config can override what is set by the variant
+         fms_config_dict = self._convert_model_config(model_config)
+
+         variant_match = self._find_variant(fms_config_dict, fms_architecture_name)
        if variant_match is None:
-             raise ValueError(f"Unable to determine model variant from model: {model_path}")
+             raise ValueError(
+                 f"Unable to determine model variant for model: {model_path}"
+             )

        # get_model does not have a dtype parameter; setting the default dtype
        # reduces memory required to load the model compared to converting afterwards
        orig_dtype = torch.get_default_dtype()
        try:
            torch.set_default_dtype(dtype)
-             calico_model = get_model("calico", variant_match, model_path, source="megatron", device_type=self.device.type)
-
-             # the CalicoConfig does not persist token ids to the
-             # HFAdaptedConfig, so pass them along explicitly
-             self.model = to_hf_api(
-                 calico_model,
-                 pad_token_id=model_config['pad_token_id'],
-                 bos_token_id=model_config['bos_token_id'],
-                 eos_token_id=model_config['eos_token_id'],
-             ).requires_grad_(False).eval()
+             fms_model = get_model(
+                 fms_architecture_name,
+                 variant_match,
+                 model_path,
+                 source="megatron",
+                 device_type=self.device.type,
+                 **fms_config_dict,
+             )
+             # get_model does not persist token ids to the HFAdaptedConfig, so
+             # pass them along explicitly
+             self.model = (
+                 to_hf_api(
+                     fms_model,
+                     pad_token_id=model_config["pad_token_id"],
+                     bos_token_id=model_config["bos_token_id"],
+                     eos_token_id=model_config["eos_token_id"],
+                 )
+                 .requires_grad_(False)
+                 .eval()
+             )
        finally:
            torch.set_default_dtype(orig_dtype)

        # update the config to a config object instead of a dict
        self._config = self.model.config

    @classmethod
-     def _find_variant(cls, model_config_dict: Dict, fms_architecture) -> Optional[str]:
+     def _convert_model_config(cls, model_config_dict: Dict) -> Dict:
+         # mapping between CalicoConfig attributes and keys in the model_config dict
+         act_fn_attr = "activation_fn"
+         fms_config_attr_to_config_json_key = {
+             "src_vocab_size": "vocab_size",
+             "emb_dim": "n_embd",
+             "norm_eps": "layer_norm_epsilon",
+             "nheads": "n_head",
+             "kvheads": "num_key_value_heads",
+             "nlayers": "n_layer",
+             "pad_id": "pad_token_id",
+             # No entry in config.json for hidden_growth_factor
+             # No entry in config.json for multiple_of
+             act_fn_attr: "activation_function",
+             # ignore p_dropout for inference
+             "max_expected_seq_len": "n_positions",
+         }
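+         # keys missing from model_config fall back to the variant's defaults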
+         fms_config_dict = {
+             attr: model_config_dict[key]
+             for attr, key in fms_config_attr_to_config_json_key.items()
+             if key in model_config_dict
+         }
+
+         # the activation function name may need to be converted
+         if act_fn := fms_config_dict.get(act_fn_attr):
+             fms_config_dict[act_fn_attr] = cls._convert_activation_function_name(act_fn)
+
+         return fms_config_dict
+
+     @classmethod
+     def _convert_activation_function_name(cls, act_name: str) -> str:
+         """Attempts to find an FMS-compatible activation function name.
+
+         gpt_megatron models may use different names for the activation function
+         compared to FMS, specifically around whether "GLU" is indicated
+         explicitly.
+
+         Refer to the fms.utils.activation module to see supported names.
+         """
+         glu_activation_function_mapping = {
+             "geglu": "gelu",
+             "miglu": "mish",
+             "mishglu": "mish",
+             "reglu": "relu",
+             "swiglu": "swish",
+         }
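+         # e.g. both "swiglu" and "swish_glu" resolve to "swish"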
+         if act_name.endswith("_glu"):
+             # strip the "_glu" suffix (str.rstrip would strip characters, not
+             # the suffix, e.g. "gelu_glu" -> "ge")
+             fms_act_name = act_name[: -len("_glu")]
+         elif new_name := glu_activation_function_mapping.get(act_name):
+             fms_act_name = new_name
+         else:
+             fms_act_name = act_name
+
+         # ensure the final act name is supported by FMS
+         if fms_act_name not in FMS_ACT_2_CLASS:
+             raise ValueError(f"Unsupported activation function: {act_name}.")
+
+         return fms_act_name
+
+     @classmethod
+     def _find_variant(
+         cls, fms_config_dict: Dict, fms_architecture: str
+     ) -> Optional[str]:
        # get a list of variant configs to compare against the fms_config_dict
        variant_map = {
            # HACK: extract the variant config from the closure created for the factory functions...
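            # (relies on FMS internals: each variant factory function's first
            # closure cell is assumed to hold its CalicoConfig instance)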
            v: _fms_models[fms_architecture][v].__closure__[0].cell_contents
            for v in list_variants(fms_architecture)
        }
-         for v, v_config in variant_map.items():
-             if cls._is_variant_compatible(model_config_dict, v_config):
-                 return v
-         return None

-     @classmethod
-     def _is_variant_compatible(cls, model_config_dict: Dict, config: CalicoConfig) -> bool:
-         dict_key_to_attr = {
-             'vocab_size': 'src_vocab_size',
-             'n_embd': 'emb_dim',
-             'n_head': 'nheads',
-             'num_key_value_heads': 'kvheads',
-             'n_layer': 'nlayers',
-             'n_positions': 'max_expected_seq_len',
-             'pad_token_id': 'pad_id',
-         }
-         for key, attr in dict_key_to_attr.items():
-             if model_config_dict[key] != getattr(config, attr, None):
-                 return False
-         return True
+         # attributes of the CalicoConfig that must exist and match to find a
+         # compatible "variant"
+         variant_attrs_to_check = [
+             "emb_dim",
+             "nheads",
+             "kvheads",
+             "nlayers",
+         ]
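+         # (other config values such as src_vocab_size or max_expected_seq_len are
+         # not compared here; they are passed through to get_model as overrides)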
+         if not all(
+             fms_config_dict.get(attr) is not None
+             for attr in variant_attrs_to_check
+         ):
+             raise ValueError(
+                 "Unable to find a compatible variant: the following "
+                 f"configuration values must be set: {variant_attrs_to_check}"
+             )
+
+         for v_name, v_config in variant_map.items():
+             if all(
+                 fms_config_dict.get(attr) == getattr(v_config, attr)
+                 for attr in variant_attrs_to_check
+             ):
+                 return v_name
+         return None