@@ -175,6 +175,16 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         model.post_init()
 
         # if present, initialize sharding config. We need head_dim for colwise sharding.
+        self._set_sharding_config(model_config)
+
+        # patch forward method
+        model.forward = types.MethodType(self._simple_forward, model)
+
+        model.eval()
+        return model
+
+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Set the sharding config for the model."""
         self._sharding_config = {}
         self._sharding_config["head_dim"] = 1
         if hasattr(model_config, "base_model_tp_plan"):
@@ -183,30 +193,6 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
             self._sharding_config["head_dim"] = model_config.head_dim
         if hasattr(model_config, "num_hidden_layers"):
             self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
-        # if it is a multi-modal factory, overwrite the sharding config with the
-        # dedicated sub-configs
-        if hasattr(model_config, "sub_configs") and len(model_config.sub_configs) > 0:
-            # for image-text-to-text models, we only support sharding for the text sub-config
-            if isinstance(self, AutoModelForImageTextToTextFactory):
-                text_config = model_config.sub_configs["text_config"]
-                # if text_config is a class, instantiate it
-                if isinstance(text_config, type):
-                    text_config = text_config()
-                if hasattr(text_config, "base_model_tp_plan"):
-                    self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
-                if hasattr(text_config, "head_dim"):
-                    self._sharding_config["head_dim"] = text_config.head_dim
-                if hasattr(text_config, "num_hidden_layers"):
-                    self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
-            else:
-                # TODO: support sharding for other multi-modal models
-                pass
-
-        # patch forward method
-        model.forward = types.MethodType(self._simple_forward, model)
-
-        model.eval()
-        return model
 
     def get_sharding_config(self):
         return self._sharding_config or {}
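The tail of `_build_model` binds `self._simple_forward` onto the model instance with `types.MethodType`. A minimal, self-contained sketch of that binding mechanism, using hypothetical stand-in names (`TinyModel`, `_patched_forward`) rather than the factory's real code:

```python
import types

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def _patched_forward(self, x):
    # Once bound via MethodType, `self` is the model instance.
    return self.linear(x) * 2.0


model = TinyModel()
# Bind the plain function as a method on this instance only; other
# TinyModel instances keep the class-level forward.
model.forward = types.MethodType(_patched_forward, model)
out = model.forward(torch.ones(2, 4))
print(out.shape)  # torch.Size([2, 4])
```

Because the binding is an instance attribute, each factory-built model carries its own patched `forward` while the class definition stays untouched.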
@@ -394,6 +380,20 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
             },
         }
 
+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Set the sharding config for the model."""
+        self._sharding_config = {}
+        text_config = model_config.sub_configs["text_config"]
+        # if text_config is a class, instantiate it
+        if isinstance(text_config, type):
+            text_config = text_config()
+        if hasattr(text_config, "base_model_tp_plan"):
+            self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
+        if hasattr(text_config, "head_dim"):
+            self._sharding_config["head_dim"] = text_config.head_dim
+        if hasattr(text_config, "num_hidden_layers"):
+            self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
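Taken together, the hunks replace the `isinstance(self, AutoModelForImageTextToTextFactory)` branch with ordinary method overriding: `_build_model` calls `self._set_sharding_config(...)`, so the subclass hook wins automatically. A hedged sketch of that dispatch; the classes below are reduced stand-ins for illustration, not the real factory implementations:

```python
from types import SimpleNamespace


class AutoModelFactory:
    """Reduced stand-in for the base factory; only the dispatch is shown."""

    def _set_sharding_config(self, model_config):
        # Default path: read sharding hints straight off the top-level config.
        self._sharding_config = {"head_dim": 1}
        if hasattr(model_config, "head_dim"):
            self._sharding_config["head_dim"] = model_config.head_dim

    def build(self, model_config):
        # Calling through `self` picks up a subclass override without any
        # isinstance() branching in the base class.
        self._set_sharding_config(model_config)
        return self._sharding_config


class AutoModelForImageTextToTextFactory(AutoModelFactory):
    def _set_sharding_config(self, model_config):
        # Image-text-to-text models: shard based on the text sub-config only.
        self._sharding_config = {}
        text_config = model_config.sub_configs["text_config"]
        if isinstance(text_config, type):  # config classes get instantiated
            text_config = text_config()
        if hasattr(text_config, "head_dim"):
            self._sharding_config["head_dim"] = text_config.head_dim


cfg = SimpleNamespace(sub_configs={"text_config": SimpleNamespace(head_dim=128)})
print(AutoModelForImageTextToTextFactory().build(cfg))  # {'head_dim': 128}
```

Moving the multi-modal special case into the subclass also removes the base class's compile-time dependency on `AutoModelForImageTextToTextFactory`, keeping each factory responsible for its own sharding rules.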