@@ -97,41 +97,22 @@ def _add_defaults_and_concat(
     )
 
 
-# registration of new adapter steps for each architecture
-serialization.register_adapter_step("llama", "int8_qparams_aiu", _int8_qparams_aiu)
-serialization.register_adapter_step(
-    "gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
-)
-serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
-serialization.register_adapter_step(
-    "roberta_question_answering",
-    "int8_qparams_aiu",
-    _int8_qparams_aiu,
-)
-
-# registration of multi-step adapter for each architecture
-serialization.register_adapter(
+# registration of new adapter step and adapter for each architecture
+for arch in [
     "llama",
-    "fms_mo",
-    [
-        "hf_to_fms_names",
-        "hf_to_fms_rope",
-        "weight_fusion",
-        "int8_qparams_aiu",
-    ],
-)
-serialization.register_adapter(
-    "gpt_bigcode", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
-)
-serialization.register_adapter(
-    "roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
-)
-serialization.register_adapter(
+    "gpt_bigcode",
+    "granite",
+    "roberta",
     "roberta_question_answering",
-    "fms_mo",
-    [
-        "hf_to_fms_names",
-        "weight_fusion",
-        "int8_qparams_aiu",
-    ],
-)
+]:
+    serialization.register_adapter_step(arch, "int8_qparams_aiu", _int8_qparams_aiu)
+    if arch in ["llama", "granite"]:
+        steps_to_register = [
+            "hf_to_fms_names",
+            "hf_to_fms_rope",
+            "weight_fusion",
+            "int8_qparams_aiu",
+        ]
+    else:
+        steps_to_register = ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
+    serialization.register_adapter(arch, "fms_mo", steps_to_register)
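
For context, the registration loop in this hunk only makes the composed "fms_mo" adapter available for each architecture; a caller still selects it when loading a checkpoint. Below is a minimal usage sketch, not part of the diff, assuming the standard `fms.models.get_model` entry point with its `source` argument; the variant name and checkpoint path are placeholders.

```python
# Minimal usage sketch (not part of this PR). Assumes the module containing the
# registration loop above has already been imported, so the "fms_mo" source is
# registered for the listed architectures.
from fms.models import get_model

model = get_model(
    architecture="granite",           # any architecture registered in the loop
    variant="8b",                     # placeholder variant name
    model_path="/path/to/int8-ckpt",  # placeholder int8 fms-mo checkpoint
    source="fms_mo",                  # composed adapter registered above
)
```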