
Commit 5eb4a8d

Transformation code cleanup
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 77495f5 commit 5eb4a8d

7 files changed (+260 / -243 lines)


tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 2 additions & 2 deletions
@@ -157,8 +157,8 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "If False, auto-detect and use column+row (all_reduce) sharding when possible.",
     )
 
-    use_sharding_from_config: bool = Field(
-        default=True,
+    use_sharding_from_factory: bool = Field(
+        default=False,
         description="If True, use sharding from the model config (if present). "
         "If False, run heuristics to detect sharding.",
     )
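
For context, the renamed flag is an ordinary Pydantic settings field, so a run opts back into factory-provided sharding by overriding the new False default. Below is a minimal, runnable sketch of the same pattern; ShardingSettings is an illustrative stand-in, not the real AutoDeployConfig, which defines many more fields.

# Minimal sketch of the Field pattern above; ShardingSettings is a stand-in,
# not the real AutoDeployConfig.
from pydantic import Field
from pydantic_settings import BaseSettings


class ShardingSettings(BaseSettings):
    use_sharding_from_factory: bool = Field(
        default=False,
        description="If True, use sharding from the model config (if present). "
        "If False, run heuristics to detect sharding.",
    )


# The new default is False; enable factory-provided sharding explicitly per run.
settings = ShardingSettings(use_sharding_from_factory=True)
print(settings.use_sharding_from_factory)  # True

Flipping the default to False keeps the heuristic sharding detection as the out-of-the-box behavior unless the factory plan is requested explicitly.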

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,10 @@ def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or None if not quantized."""
         return {}
 
+    def get_sharding_config(self):
+        """Returns the sharding config for this model or None if not sharded."""
+        return {}
+
     def get_cache_config(self) -> CacheConfig:
         """Return the cache configuration for the model.
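
This gives every factory a uniform accessor that callers can query without knowing the concrete factory type: the base implementation returns an empty dict, and subclasses that populate self._sharding_config override it (as hf.py does below). A short sketch of the hook pattern, with stand-in class names and illustrative config values:

# Sketch of the accessor hook added above; class names and the example
# _sharding_config contents are illustrative stand-ins.
from typing import Dict


class BaseFactory:
    def get_sharding_config(self) -> Dict:
        """Returns the sharding config for this model or None if not sharded."""
        return {}


class HFLikeFactory(BaseFactory):
    def __init__(self):
        # populated from the model config during model build in the real code
        self._sharding_config = {"head_dim": 64, "num_hidden_layers": 32}

    def get_sharding_config(self) -> Dict:
        return self._sharding_config or {}


for factory in (BaseFactory(), HFLikeFactory()):
    print(type(factory).__name__, factory.get_sharding_config())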

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 24 additions & 24 deletions
@@ -175,6 +175,16 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         model.post_init()
 
         # if present, initialize sharding config. We need head_dim for colwise sharding.
+        self._set_sharding_config(model_config)
+
+        # patch forward method
+        model.forward = types.MethodType(self._simple_forward, model)
+
+        model.eval()
+        return model
+
+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Set the sharding config for the model."""
         self._sharding_config = {}
         self._sharding_config["head_dim"] = 1
         if hasattr(model_config, "base_model_tp_plan"):
@@ -183,30 +193,6 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
             self._sharding_config["head_dim"] = model_config.head_dim
         if hasattr(model_config, "num_hidden_layers"):
             self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
-        # if it is a multi-modal factory, overwrite the sharding config with the
-        # dedicated sub-configs
-        if hasattr(model_config, "sub_configs") and len(model_config.sub_configs) > 0:
-            # for image-text-to-text models, we only support sharding for the text sub-config
-            if isinstance(self, AutoModelForImageTextToTextFactory):
-                text_config = model_config.sub_configs["text_config"]
-                # if text_config is a class, instantiate it
-                if isinstance(text_config, type):
-                    text_config = text_config()
-                if hasattr(text_config, "base_model_tp_plan"):
-                    self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
-                if hasattr(text_config, "head_dim"):
-                    self._sharding_config["head_dim"] = text_config.head_dim
-                if hasattr(text_config, "num_hidden_layers"):
-                    self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
-            else:
-                # TODO: support sharding for other multi-modal models
-                pass
-
-        # patch forward method
-        model.forward = types.MethodType(self._simple_forward, model)
-
-        model.eval()
-        return model
 
     def get_sharding_config(self):
         return self._sharding_config or {}
@@ -394,6 +380,20 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
             },
         }
 
+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Set the sharding config for the model."""
+        self._sharding_config = {}
+        text_config = model_config.sub_configs["text_config"]
+        # if text_config is a class, instantiate it
+        if isinstance(text_config, type):
+            text_config = text_config()
+        if hasattr(text_config, "base_model_tp_plan"):
+            self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
+        if hasattr(text_config, "head_dim"):
+            self._sharding_config["head_dim"] = text_config.head_dim
+        if hasattr(text_config, "num_hidden_layers"):
+            self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
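
Net effect of the hf.py change: the multimodal special-casing moves out of _build_model and into an overridable _set_sharding_config hook, so AutoModelForImageTextToTextFactory reads its tensor-parallel plan from the nested text_config instead of an isinstance check in the shared path. A simplified sketch of that dispatch, with SimpleNamespace objects standing in for transformers.PretrainedConfig; unlike the real override, which re-reads the attributes itself, the sketch delegates to the base class via super() to stay short.

# Simplified sketch of the _set_sharding_config override pattern; the config
# objects are SimpleNamespace stand-ins for transformers.PretrainedConfig.
from types import SimpleNamespace


class TextFactory:
    def _set_sharding_config(self, model_config):
        self._sharding_config = {"head_dim": 1}
        if hasattr(model_config, "base_model_tp_plan"):
            self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
        if hasattr(model_config, "head_dim"):
            self._sharding_config["head_dim"] = model_config.head_dim
        if hasattr(model_config, "num_hidden_layers"):
            self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers


class ImageTextToTextFactory(TextFactory):
    def _set_sharding_config(self, model_config):
        # only the text sub-config participates in sharding
        text_config = model_config.sub_configs["text_config"]
        if isinstance(text_config, type):
            text_config = text_config()  # instantiate if given as a class
        super()._set_sharding_config(text_config)


config = SimpleNamespace(
    sub_configs={
        "text_config": SimpleNamespace(
            base_model_tp_plan={"q_proj": "colwise"}, head_dim=128, num_hidden_layers=32
        )
    }
)
factory = ImageTextToTextFactory()
factory._set_sharding_config(config)
print(factory._sharding_config["tp_plan"])  # {'q_proj': 'colwise'}

Either way, the populated dict is what get_sharding_config() returns.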
