Commit 59901ba

Cleanup sharding interface
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 4c6dcc1 commit 59901ba

4 files changed: +57 -48 lines changed

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 7 additions & 8 deletions
@@ -12,10 +12,8 @@
 from ..custom_ops.attention_interface import CacheConfig
 from ..utils.logger import ad_logger
 
-
 
-
-class FactorySource(Enum):
+class ShardingConfigSource(Enum):
     """Enum for factory source."""
 
     HUGGINGFACE = "huggingface"
@@ -48,6 +46,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self._prefetched_model_path: Optional[str] = None
         self._prefetched_tokenizer_path: Optional[str] = None
+        self._sharding_config: Dict[str, Any] = {}
 
     @property
     def model(self) -> Optional[str]:
@@ -106,9 +105,9 @@ def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or None if not quantized."""
         return {}
 
-    def get_sharding_config(self):
-        """Returns the sharding config for this model or None if not sharded."""
-        return {}
+    def get_sharding_config(self) -> Dict:
+        """Returns the sharding config for this model."""
+        return self._sharding_config
 
     def get_cache_config(self) -> CacheConfig:
         """Return the cache configuration for the model.
@@ -118,13 +117,13 @@ def get_cache_config(self) -> CacheConfig:
         """
         return CacheConfig()
 
-    def get_model_source(self) -> FactorySource:
+    def get_sharding_config_source(self) -> ShardingConfigSource:
         """Return the source of the model factory.
 
         Returns:
             The source identifier for this model factory.
         """
-        return FactorySource.UNKNOWN
+        return ShardingConfigSource.UNKNOWN
 
     def init_tokenizer(self) -> Optional[Any]:
         """Initialize the tokenizer for the model.

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 15 additions & 20 deletions
@@ -30,7 +30,7 @@
 from ..custom_ops.attention_interface import CacheConfig
 from ..utils._config import deep_merge_dicts
 from ..utils.logger import ad_logger
-from .factory import FactorySource, ModelFactory, ModelFactoryRegistry
+from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource
 
 
 @contextmanager
@@ -175,7 +175,7 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         model.post_init()
 
         # if present, initialize sharding config. We need head_dim for colwise sharding.
-        self._set_sharding_config(model_config)
+        self._set_sharding_config(model.config)
 
         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
@@ -185,7 +185,6 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
 
     def _set_sharding_config(self, model_config: PretrainedConfig):
         """Set the sharding config for the model."""
-        self._sharding_config = {}
         self._sharding_config["head_dim"] = 1
         if hasattr(model_config, "base_model_tp_plan"):
             self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
@@ -194,9 +193,6 @@ def _set_sharding_config(self, model_config: PretrainedConfig):
         if hasattr(model_config, "num_hidden_layers"):
             self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
 
-    def get_sharding_config(self):
-        return self._sharding_config or {}
-
     def get_quant_config(self) -> Dict:
         return self._quant_config or {}
 
@@ -213,13 +209,13 @@ def get_cache_config(self):
         kv_cache_dtype = None
         return CacheConfig(dtype=kv_cache_dtype)
 
-    def get_model_source(self) -> FactorySource:
+    def get_sharding_config_source(self) -> ShardingConfigSource:
         """Return the source of the model factory.
 
         Returns:
             The source identifier for this model factory.
         """
-        return FactorySource.HUGGINGFACE
+        return ShardingConfigSource.HUGGINGFACE
 
     def init_tokenizer(self) -> Optional[Any]:
         """Initialize the tokenizer—either a custom name or the model's default."""
@@ -389,18 +385,17 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
         }
 
     def _set_sharding_config(self, model_config: PretrainedConfig):
-        """Set the sharding config for the model."""
-        self._sharding_config = {}
-        text_config = model_config.sub_configs["text_config"]
-        # if text_config is a class, instantiate it
-        if isinstance(text_config, type):
-            text_config = text_config()
-        if hasattr(text_config, "base_model_tp_plan"):
-            self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
-        if hasattr(text_config, "head_dim"):
-            self._sharding_config["head_dim"] = text_config.head_dim
-        if hasattr(text_config, "num_hidden_layers"):
-            self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
+        """Override the sharding config for the model with text_config."""
+        super()._set_sharding_config(model_config)
+
+        if hasattr(model_config, "text_config"):
+            text_config = model_config.text_config
+            if hasattr(text_config, "base_model_tp_plan"):
+                self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
+            if hasattr(text_config, "head_dim"):
+                self._sharding_config["head_dim"] = text_config.head_dim
+            if hasattr(text_config, "num_hidden_layers"):
+                self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
 
     @property
     def automodel_from_config(self):
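
The VLM-style factory now reuses the base logic and only patches in values from text_config when one is present, instead of rebuilding the dict from sub_configs. A rough standalone sketch of that pattern, using SimpleNamespace objects in place of transformers.PretrainedConfig and hypothetical Demo* class names (the real base class may read additional fields not visible in this diff):

# Sketch of the text_config override pattern; all names here are illustrative.
from types import SimpleNamespace
from typing import Any, Dict


class DemoHFFactory:
    def __init__(self) -> None:
        self._sharding_config: Dict[str, Any] = {}

    def _set_sharding_config(self, model_config: Any) -> None:
        # Base behavior: read fields straight off the top-level config.
        self._sharding_config["head_dim"] = 1
        if hasattr(model_config, "base_model_tp_plan"):
            self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
        if hasattr(model_config, "num_hidden_layers"):
            self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers


class DemoVLMFactory(DemoHFFactory):
    def _set_sharding_config(self, model_config: Any) -> None:
        # Override: start from the base values, then prefer text_config fields.
        super()._set_sharding_config(model_config)
        if hasattr(model_config, "text_config"):
            text_config = model_config.text_config
            if hasattr(text_config, "base_model_tp_plan"):
                self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
            if hasattr(text_config, "head_dim"):
                self._sharding_config["head_dim"] = text_config.head_dim
            if hasattr(text_config, "num_hidden_layers"):
                self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers


# Illustrative config: the language tower lives under text_config.
vlm_config = SimpleNamespace(
    num_hidden_layers=2,
    text_config=SimpleNamespace(
        base_model_tp_plan={"layers.*.self_attn.q_proj": "colwise"},  # example entry
        head_dim=128,
        num_hidden_layers=4,
    ),
)
factory = DemoVLMFactory()
factory._set_sharding_config(vlm_config)
print(factory._sharding_config)  # text_config values win where present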

tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py

Lines changed: 32 additions & 14 deletions
@@ -30,7 +30,7 @@
 from pydantic import BaseModel, ConfigDict, Field
 from torch.fx import GraphModule, Node
 
-from ...models.factory import FactorySource
+from ...models.factory import ModelFactory, ShardingConfigSource
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     extract_param_names_from_lin_node,
@@ -255,7 +255,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""
 
-    factory_source: FactorySource
+    factory_source: ShardingConfigSource
     rank: int
     world_size: int
     _predefined_config: Optional[Dict[str, Any]] = None
@@ -267,7 +267,7 @@ class ShardingConfig(BaseModel):
 
     def __init__(
         self,
-        factory_source: FactorySource,
+        factory_source: ShardingConfigSource,
         rank: int,
         world_size: int,
         sharding_config: Dict[str, Any] = None,
@@ -290,30 +290,30 @@ def __init__(
         self.validate_config()
 
     def validate_config(self) -> bool:
-        if self.factory_source != FactorySource.HUGGINGFACE:
+        if self.factory_source != ShardingConfigSource.HUGGINGFACE:
             ad_logger.warning(
                 "Sharding config is is currently only " + "supported for HuggingFace. Skipping."
             )
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
 
         if not isinstance(self._predefined_config, dict):
             ad_logger.warning("Sharding config is not a dictionary. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
 
         if "head_dim" not in self._predefined_config:
             ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
 
         if "tp_plan" not in self._predefined_config:
             ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
         tp_plan = self._predefined_config["tp_plan"]
 
@@ -333,7 +333,7 @@ def validate_config(self) -> bool:
         if not values.issubset(allowed_values):
             ad_logger.warning("Sharding config contains invalid values. Skipping.")
             # invalidate the config
-            self._predefined_config = None
+            self._predefined_config = {}
             return False
         return True
 
@@ -727,10 +727,26 @@ def _append_simple_shard(
     sharding_config.tp_transforms.extend(tp_shards)
 
 
-def detect_sharding(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+def detect_sharding(
+    gm: GraphModule,
+    factory: ModelFactory,
+    local_rank: int,
+    world_size: int,
+    simple_shard_only: bool,
+    use_sharding_from_factory: bool,
+) -> ShardingConfig:
+    sharding_config = ShardingConfig(
+        factory.get_sharding_config_source(),
+        local_rank,
+        world_size,
+        factory.get_sharding_config(),
+        simple_shard_only,
+        use_sharding_from_factory,
+    )
+
     if (
         sharding_config.use_sharding_from_factory
-        and sharding_config.get_predefined_config() is not None
+        and len(sharding_config.get_predefined_config()) > 0
     ):
         ad_logger.info("Applying sharding from config")
         detect_sharding_from_factory_config(gm, sharding_config)
@@ -746,6 +762,8 @@ def detect_sharding(gm: GraphModule, sharding_config: ShardingConfig) -> None:
     # run BMM sharding across ranks
     detect_dp_bmm_shard(gm, sharding_config)
 
+    return sharding_config
+
 
 def detect_column_row_shard(
     gm: GraphModule,
@@ -771,7 +789,7 @@ def detect_column_row_shard(
 
     rank, world_size = sharding_config.rank, sharding_config.world_size
     if world_size < 2:
-        ad_logger.info("Skipping sharding for single device")
+        ad_logger.info("Skipping TP sharding for single device")
         return
 
     assert isinstance(gm, GraphModule), "Expecting GraphModule"
@@ -937,7 +955,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
     ad_logger.debug("Before sharding graph: " + str(gm))
     rank, world_size = sharding_config.rank, sharding_config.world_size
    if world_size < 2:
-        ad_logger.info("Skipping sharding for single device")
+        ad_logger.info("Skipping DP BMM sharding for single device")
         return
 
     assert isinstance(gm, GraphModule), "Expecting GraphModule"
@@ -1008,7 +1026,7 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> None:
 
     rank, world_size = sharding_config.rank, sharding_config.world_size
     if world_size < 2:
-        ad_logger.info("Skipping sharding for single device")
+        ad_logger.info("Skipping EP sharding for single device")
         return
 
     assert isinstance(gm, GraphModule), "Expecting GraphModule"
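
Two details of this file are worth calling out. First, detect_sharding now builds the ShardingConfig itself from the factory and returns it, rather than taking a pre-built config. Second, validate_config resets an invalid predefined config to {} instead of None, so the caller can use a uniform len(...) > 0 check. A rough standalone sketch of that flow, assuming get_predefined_config() simply returns _predefined_config (its body is not part of this diff) and using simplified stand-ins for the real classes and graph passes:

# Illustrative sketch only; the real detection passes operate on a torch.fx GraphModule.
from typing import Any, Dict


class DemoShardingConfig:
    def __init__(self, source: str, predefined: Dict[str, Any], use_from_factory: bool):
        self.factory_source = source
        self.use_sharding_from_factory = use_from_factory
        self._predefined_config = predefined
        self.validate_config()

    def validate_config(self) -> bool:
        # Every failed check now resets the dict to {} (not None), so downstream
        # emptiness checks never have to special-case None.
        if self.factory_source != "huggingface":
            self._predefined_config = {}
            return False
        for key in ("head_dim", "tp_plan"):
            if key not in self._predefined_config:
                self._predefined_config = {}
                return False
        return True

    def get_predefined_config(self) -> Dict[str, Any]:
        return self._predefined_config


def demo_detect_sharding(source: str, factory_cfg: Dict[str, Any],
                         use_from_factory: bool) -> DemoShardingConfig:
    # New shape of detect_sharding: build the config from factory-provided data,
    # pick a detection path, and hand the config back to the caller.
    cfg = DemoShardingConfig(source, factory_cfg, use_from_factory)
    if cfg.use_sharding_from_factory and len(cfg.get_predefined_config()) > 0:
        print("Applying sharding from config")        # factory tp_plan path
    else:
        print("Falling back to heuristic detection")  # column/row, BMM, EP passes
    return cfg


demo_detect_sharding("huggingface", {}, True)  # empty config -> heuristic path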

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 3 additions & 6 deletions
@@ -15,7 +15,6 @@
 from ..utils.logger import ad_logger
 from ._graph import canonicalize_graph, lift_to_meta, move_to_device
 from .library import (
-    ShardingConfig,
     detect_sharding,
     eliminate_redundant_transposes,
     fuse_allreduce_residual_rmsnorm,
@@ -112,17 +111,15 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
         optimize_rope(egm)
 
-        sharding_config = ShardingConfig(
-            self.factory.get_model_source(),
+        sharding_config = detect_sharding(
+            egm,
+            self.factory,
             local_rank,
             world_size,
-            self.factory.get_sharding_config(),
             self.ad_config.simple_shard_only,
             self.ad_config.use_sharding_from_factory,
         )
 
-        detect_sharding(egm, sharding_config)
-
         sharding_transform_executor(egm, sharding_config)
 
         # let's run a shape propagation pass to update the graph with correct meta values for
