
Commit 4c6dcc1

Changed sharding config logic
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent a891808 commit 4c6dcc1

4 files changed, +148 -75 lines changed

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 18 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 import copy
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Callable, Dict, Optional, Type
 
 import torch
@@ -11,6 +12,15 @@
 from ..custom_ops.attention_interface import CacheConfig
 from ..utils.logger import ad_logger
 
+
+
+
+class FactorySource(Enum):
+    """Enum for factory source."""
+
+    HUGGINGFACE = "huggingface"
+    UNKNOWN = "unknown"
 
 class ModelFactory(ABC):
     """An interface to return and correctly initialize a model from a desired source.
@@ -108,6 +118,14 @@ def get_cache_config(self) -> CacheConfig:
         """
         return CacheConfig()
 
+    def get_model_source(self) -> FactorySource:
+        """Return the source of the model factory.
+
+        Returns:
+            The source identifier for this model factory.
+        """
+        return FactorySource.UNKNOWN
+
     def init_tokenizer(self) -> Optional[Any]:
         """Initialize the tokenizer for the model.
 
tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 9 additions & 1 deletion
@@ -30,7 +30,7 @@
 from ..custom_ops.attention_interface import CacheConfig
 from ..utils._config import deep_merge_dicts
 from ..utils.logger import ad_logger
-from .factory import ModelFactory, ModelFactoryRegistry
+from .factory import FactorySource, ModelFactory, ModelFactoryRegistry
 
 
 @contextmanager
@@ -213,6 +213,14 @@ def get_cache_config(self):
             kv_cache_dtype = None
         return CacheConfig(dtype=kv_cache_dtype)
 
+    def get_model_source(self) -> FactorySource:
+        """Return the source of the model factory.
+
+        Returns:
+            The source identifier for this model factory.
+        """
+        return FactorySource.HUGGINGFACE
+
     def init_tokenizer(self) -> Optional[Any]:
         """Initialize the tokenizer—either a custom name or the model's default."""
         if self.tokenizer is None:
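Taken together with the factory.py change above, the new hook is just a source tag: the base factory reports UNKNOWN and the HuggingFace factory overrides it. A minimal standalone sketch of the pattern (the real ModelFactory base class and its registry are omitted, so the class names below are stand-ins):

from enum import Enum


class FactorySource(Enum):
    """Mirrors the enum added in factory.py."""

    HUGGINGFACE = "huggingface"
    UNKNOWN = "unknown"


class _BaseFactory:
    """Stand-in for ModelFactory: the default source is unknown."""

    def get_model_source(self) -> FactorySource:
        return FactorySource.UNKNOWN


class _HFFactory(_BaseFactory):
    """Stand-in for the HuggingFace factory defined in hf.py."""

    def get_model_source(self) -> FactorySource:
        return FactorySource.HUGGINGFACE


assert _BaseFactory().get_model_source() is FactorySource.UNKNOWN
assert _HFFactory().get_model_source() is FactorySource.HUGGINGFACE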

tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py

Lines changed: 118 additions & 63 deletions
@@ -30,6 +30,7 @@
 from pydantic import BaseModel, ConfigDict, Field
 from torch.fx import GraphModule, Node
 
+from ...models.factory import FactorySource
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     extract_param_names_from_lin_node,
@@ -254,9 +255,10 @@ def apply(self, gm: GraphModule, node: Node) -> None:
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""
 
-    rank: int = 0
-    world_size: int = 1
-    predefined_config: Dict[str, Any] = None
+    factory_source: FactorySource
+    rank: int
+    world_size: int
+    _predefined_config: Optional[Dict[str, Any]] = None
     simple_shard_only: bool = False
     use_sharding_from_factory: bool = False
     tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
@@ -265,21 +267,81 @@ class ShardingConfig(BaseModel):
 
     def __init__(
         self,
+        factory_source: FactorySource,
         rank: int,
         world_size: int,
         sharding_config: Dict[str, Any] = None,
         simple_shard_only: bool = False,
         use_sharding_from_factory: bool = False,
     ):
-        super().__init__()
-        self.rank = rank
-        self.world_size = world_size
-        self.predefined_config = sharding_config
-        self.simple_shard_only = simple_shard_only
-        self.use_sharding_from_factory = use_sharding_from_factory
+        super().__init__(
+            factory_source=factory_source,
+            rank=rank,
+            world_size=world_size,
+            _predefined_config=sharding_config,
+            simple_shard_only=simple_shard_only,
+            use_sharding_from_factory=use_sharding_from_factory,
+        )
+
+        # Pydantic does not support setting private fields directly.
+        self._predefined_config = sharding_config
+        # Validate the config after initialization
+        if self._predefined_config is not None:
+            self.validate_config()
+
+    def validate_config(self) -> bool:
+        if self.factory_source != FactorySource.HUGGINGFACE:
+            ad_logger.warning(
+                "Sharding config is currently only supported for HuggingFace. Skipping."
+            )
+            # invalidate the config
+            self._predefined_config = None
+            return False
+
+        if not isinstance(self._predefined_config, dict):
+            ad_logger.warning("Sharding config is not a dictionary. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+
+        if "head_dim" not in self._predefined_config:
+            ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+
+        if "tp_plan" not in self._predefined_config:
+            ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+        tp_plan = self._predefined_config["tp_plan"]
+
+        values = set(tp_plan.values())
+        allowed_values = {
+            "colwise",  # row split and no collective
+            "rowwise",  # column split and all-reduce
+            "gather",  # simple shard (row + all_gather)
+            # TODO: remaining values are not supported yet.
+            # They require hybrid EP+TP and/or SP support.
+            # "sequence_parallel",  # sequence parallelism
+            # "local_colwise",
+            # "local_rowwise",
+            # "local_packed_rowwise",
+            # "local",
+        }
+        if not values.issubset(allowed_values):
+            ad_logger.warning("Sharding config contains invalid values. Skipping.")
+            # invalidate the config
+            self._predefined_config = None
+            return False
+        return True
+
+    def get_predefined_config(self) -> Dict[str, Any]:
+        return self._predefined_config
 
 
-def detect_tp_sharding_from_factory_config(
+def detect_sharding_from_factory_config(
     gm: GraphModule,
     sharding_config: ShardingConfig,
 ) -> None:
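For context on the checks above: only tp_plan values the executor can currently shard are accepted, and anything else invalidates the whole factory config. A standalone sketch of just that membership check, using a hypothetical HF-style plan (the layer names are made up; the real validate_config additionally requires head_dim and a HuggingFace factory source):

# Mirrors allowed_values in validate_config above.
ALLOWED_TP_PLAN_VALUES = {"colwise", "rowwise", "gather"}


def tp_plan_is_supported(tp_plan: dict) -> bool:
    """True if every per-layer sharding action is currently supported."""
    return set(tp_plan.values()).issubset(ALLOWED_TP_PLAN_VALUES)


# Hypothetical plan: attention projections split column/row-wise, MLP simple-sharded.
plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.layers.*.mlp.down_proj": "gather",
}
assert tp_plan_is_supported(plan)

# "sequence_parallel" is deliberately not allowed yet, so such a plan is rejected.
assert not tp_plan_is_supported({**plan, "model.norm": "sequence_parallel"})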
@@ -305,54 +367,30 @@ def detect_tp_sharding_from_factory_config(
     # The following constraints are based on
     # https://github.com/huggingface/transformers/blob/d8e05951b8efd4880acca9a3f291e8b65841a86d/src/transformers/models/llama4/configuration_llama4.py#L249
 
-    if not isinstance(sharding_config.predefined_config, dict):
-        ad_logger.warning("Sharding config is not a dictionary. Skipping.")
-        return
-
-    if "head_dim" not in sharding_config.predefined_config:
-        ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
-        return
-    head_dim = sharding_config.predefined_config["head_dim"]
-
-    if "tp_plan" not in sharding_config.predefined_config:
-        ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
-        return
-    tp_plan = sharding_config.predefined_config["tp_plan"]
-
-    values = set(tp_plan.values())
-    allowed_values = {
-        "colwise",
-        "rowwise",
-        "sequence_parallel",
-        "local_colwise",
-        "local_rowwise",
-        "local_packed_rowwise",
-        "local",
-        "gather",
-    }
-    if not values.issubset(allowed_values):
-        ad_logger.warning("Sharding config contains invalid values. Skipping.")
-        return
+    factory_config = sharding_config.get_predefined_config()
+    head_dim = factory_config["head_dim"]
+    tp_plan = factory_config["tp_plan"]
 
     rank, world_size = sharding_config.rank, sharding_config.world_size
 
+    # If the node is inside the attention module, we need to set min_local_shape to the
+    # head_dim - otherwise, we would risk splitting the heads into smaller shards.
+    # TODO: is there a better way to check if we are in attention module?
+    attn_names = [
+        "attention",
+        "Attention",
+        "attn",
+        "Attn",
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+    ]
+
     for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
         # use node's weight name to get the module name
         module_name = lin_node.args[1].target
 
-        # If the node is inside the attention module, we need to set min_local_shape to the
-        # head_dim - otherwise, we would risk splitting the heads into smaller shards.
-        # TODO: is there a better way to check if we are in attention module?
-        attn_names = [
-            "attention",
-            "Attention",
-            "attn",
-            "Attn",
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-        ]
         if any(attn_name in module_name for attn_name in attn_names):
            min_local_shape = head_dim
        else:
@@ -424,7 +462,8 @@ def simple_shard_first_n_layers(sharding_config: ShardingConfig, n_layers: int)
     # instead of "*".
     """
     new_tp_plan = {}
-    for layer_pattern, config in sharding_config.predefined_config["tp_plan"].items():
+    factory_config = sharding_config.get_predefined_config()
+    for layer_pattern, config in factory_config["tp_plan"].items():
         if "*" in layer_pattern:
             # Create new dict with first n_layers entries first
 
@@ -434,7 +473,7 @@ def simple_shard_first_n_layers(sharding_config: ShardingConfig, n_layers: int)
         # Add the default config after
         new_tp_plan[layer_pattern] = config
 
-    sharding_config.predefined_config["tp_plan"] = new_tp_plan
+    sharding_config._predefined_config["tp_plan"] = new_tp_plan
 
 
 def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -> None:
@@ -446,8 +485,9 @@ def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -
     # instead of "*".
     """
     new_tp_plan = {}
-    num_layers = sharding_config.predefined_config["num_hidden_layers"]
-    for layer_pattern, config in sharding_config.predefined_config["tp_plan"].items():
+    factory_config = sharding_config.get_predefined_config()
+    num_layers = factory_config["num_hidden_layers"]
+    for layer_pattern, config in factory_config["tp_plan"].items():
         if "*" in layer_pattern:
             # Create new dict with first n_layers entries first
 
@@ -456,18 +496,18 @@ def simple_shard_last_n_layers(sharding_config: ShardingConfig, n_layers: int) -
 
         # Add the default config after
         new_tp_plan[layer_pattern] = config
-    sharding_config.predefined_config["tp_plan"] = new_tp_plan
+    sharding_config._predefined_config["tp_plan"] = new_tp_plan
 
 
 def simple_shard_attention_layers(sharding_config: ShardingConfig) -> None:
     """
     If any key in tp_plan contains "attention", replace it with "gather"
     """
-    for layer_pattern, config in sharding_config.predefined_config["tp_plan"].items():
+    for layer_pattern, config in sharding_config._predefined_config["tp_plan"].items():
         if any(
             attn_name in layer_pattern for attn_name in ["attention", "Attention", "attn", "Attn"]
         ):
-            sharding_config.predefined_config["tp_plan"][layer_pattern] = "gather"
+            sharding_config._predefined_config["tp_plan"][layer_pattern] = "gather"
 
 
 def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None:
@@ -687,6 +727,26 @@ def _append_simple_shard(
     sharding_config.tp_transforms.extend(tp_shards)
 
 
+def detect_sharding(gm: GraphModule, sharding_config: ShardingConfig) -> None:
+    if (
+        sharding_config.use_sharding_from_factory
+        and sharding_config.get_predefined_config() is not None
+    ):
+        ad_logger.info("Applying sharding from config")
+        detect_sharding_from_factory_config(gm, sharding_config)
+        return
+
+    ad_logger.info("Running autodeploy sharding heuristics")
+    # run TP sharding across ranks
+    detect_column_row_shard(gm, sharding_config)
+
+    # run EP sharding across ranks
+    detect_ep_shard(gm, sharding_config)
+
+    # run BMM sharding across ranks
+    detect_dp_bmm_shard(gm, sharding_config)
+
+
 def detect_column_row_shard(
     gm: GraphModule,
     sharding_config: ShardingConfig,
@@ -716,11 +776,6 @@ def detect_column_row_shard(
 
     assert isinstance(gm, GraphModule), "Expecting GraphModule"
 
-    if sharding_config.use_sharding_from_factory and sharding_config.predefined_config is not None:
-        ad_logger.info("Using TP sharding from config")
-        detect_tp_sharding_from_factory_config(gm, sharding_config)
-        return
-
     ad_logger.info("Running TP sharding detection")
 
     # find boundary nodes of regions we want to shard
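The net effect of this file's changes: the decision that used to live inside detect_column_row_shard now sits in the single detect_sharding entry point, which falls back to the heuristics whenever the factory plan is absent or was invalidated. A minimal sketch of just that decision (booleans stand in for the real graph and config objects; this is not the real signature):

def sharding_path(use_factory_plan: bool, plan_is_valid: bool) -> str:
    """Which path detect_sharding takes for a given configuration."""
    if use_factory_plan and plan_is_valid:
        return "factory_plan"  # detect_sharding_from_factory_config
    return "heuristics"  # detect_column_row_shard, detect_ep_shard, detect_dp_bmm_shard


assert sharding_path(True, True) == "factory_plan"
# A config rejected by validate_config() falls back to the heuristics.
assert sharding_path(True, False) == "heuristics"
assert sharding_path(False, True) == "heuristics"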

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 3 additions & 11 deletions
@@ -16,9 +16,7 @@
 from ._graph import canonicalize_graph, lift_to_meta, move_to_device
 from .library import (
     ShardingConfig,
-    detect_column_row_shard,
-    detect_dp_bmm_shard,
-    detect_ep_shard,
+    detect_sharding,
     eliminate_redundant_transposes,
     fuse_allreduce_residual_rmsnorm,
     fuse_collectives,
@@ -115,21 +113,15 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         optimize_rope(egm)
 
         sharding_config = ShardingConfig(
+            self.factory.get_model_source(),
             local_rank,
             world_size,
             self.factory.get_sharding_config(),
             self.ad_config.simple_shard_only,
             self.ad_config.use_sharding_from_factory,
         )
 
-        # run TP sharding across ranks
-        detect_column_row_shard(egm, sharding_config)
-
-        # run EP sharding across ranks
-        detect_ep_shard(egm, sharding_config)
-
-        # run BMM sharding across ranks
-        detect_dp_bmm_shard(egm, sharding_config)
+        detect_sharding(egm, sharding_config)
 
         sharding_transform_executor(egm, sharding_config)
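For reference, a hedged sketch of the updated call sequence from a caller's point of view, mirroring the hunk above. The import path and the availability of sharding_transform_executor in .library are assumptions based on this file; egm, factory, local_rank, world_size, and ad_config are supplied by the surrounding transform code and appear only as function parameters here:

from tensorrt_llm._torch.auto_deploy.transformations.library import (  # assumed import path
    ShardingConfig,
    detect_sharding,
    sharding_transform_executor,
)


def apply_sharding(egm, factory, local_rank, world_size, ad_config):
    """Sketch: one detection entry point, then the transform executor."""
    sharding_config = ShardingConfig(
        factory.get_model_source(),  # new first positional argument
        local_rank,
        world_size,
        factory.get_sharding_config(),
        ad_config.simple_shard_only,
        ad_config.use_sharding_from_factory,
    )
    detect_sharding(egm, sharding_config)  # factory plan if valid, otherwise heuristics
    sharding_transform_executor(egm, sharding_config)
    return egm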
