
Commit 2eb644e

sharding config seems to work
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 9662d81 commit 2eb644e

File tree: 5 files changed (+300, -28 lines)

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 6 additions & 0 deletions

@@ -157,6 +157,12 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "If False, auto-detect and use column+row (all_reduce) sharding when possible.",
     )
 
+    use_sharding_from_config: bool = Field(
+        default=True,
+        description="If True, use sharding from the model config (if present). "
+        "If False, run heuristics to detect sharding.",
+    )
+
     compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
         Field(
             default="torch-compile",
tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 31 additions & 0 deletions

@@ -174,12 +174,43 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         if hasattr(model, "post_init"):
             model.post_init()
 
+        # if present, initialize sharding config. We need head_dim for colwise sharding.
+        self._sharding_config = {}
+        self._sharding_config["head_dim"] = 1
+        if hasattr(model_config, "base_model_tp_plan"):
+            self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+        if hasattr(model_config, "head_dim"):
+            self._sharding_config["head_dim"] = model_config.head_dim
+        if hasattr(model_config, "num_hidden_layers"):
+            self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
+        # if it is a multi-modal factory, overwrite the sharding config with the
+        # dedicated sub-configs
+        if hasattr(model_config, "sub_configs") and len(model_config.sub_configs) > 0:
+            # for image-text-to-text models, we only support sharding for the text sub-config
+            if isinstance(self, AutoModelForImageTextToTextFactory):
+                text_config = model_config.sub_configs["text_config"]
+                # if text_config is a class, instantiate it
+                if isinstance(text_config, type):
+                    text_config = text_config()
+                if hasattr(text_config, "base_model_tp_plan"):
+                    self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
+                if hasattr(text_config, "head_dim"):
+                    self._sharding_config["head_dim"] = text_config.head_dim
+                if hasattr(text_config, "num_hidden_layers"):
+                    self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
+            else:
+                # TODO: support sharding for other multi-modal models
+                pass
+
         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
 
         model.eval()
         return model
 
+    def get_sharding_config(self):
+        return self._sharding_config or {}
+
     def get_quant_config(self) -> Dict:
         return self._quant_config or {}

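For reference, a sketch of the kind of dictionary get_sharding_config() would return for a Llama-style text config; the concrete values below (head_dim, layer count, submodule names) are illustrative assumptions, not part of this commit, and depend entirely on the model's base_model_tp_plan.

    # Illustrative only: a typical HF-style wildcard tp_plan keyed by submodule path,
    # plus the head_dim and num_hidden_layers collected in _build_model above.
    example_sharding_config = {
        "head_dim": 128,
        "num_hidden_layers": 32,
        "tp_plan": {
            "layers.*.self_attn.q_proj": "colwise",
            "layers.*.self_attn.k_proj": "colwise",
            "layers.*.self_attn.v_proj": "colwise",
            "layers.*.self_attn.o_proj": "rowwise",
            "layers.*.mlp.gate_proj": "colwise",
            "layers.*.mlp.up_proj": "colwise",
            "layers.*.mlp.down_proj": "rowwise",
        },
    }
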
tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py

Lines changed: 193 additions & 1 deletion

@@ -18,11 +18,12 @@
 
 import math
 import operator
+import re
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from enum import IntEnum
 from functools import partial
-from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set
+from typing import Any, Callable, DefaultDict, Dict, List, Literal, Optional, Set
 
 import torch
 import torch.nn as nn

@@ -32,6 +33,7 @@
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     extract_param_names_from_lin_node,
+    filtered_nodes,
     identify_regions_between_residuals,
     is_linear_op,
     is_op,

@@ -248,10 +250,200 @@ def apply(self, gm: GraphModule, node: Node) -> None:
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""
 
+    rank: int = 0
+    world_size: int = 1
+    predefined_config: Dict[str, Any] = None
     tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
     ep_transforms: List[EPShardingInfo] = Field(default_factory=list)
 
+    def __init__(self, rank: int, world_size: int, sharding_config: Dict[str, Any] = None):
+        super().__init__()
+        self.rank = rank
+        self.world_size = world_size
+        self.predefined_config = sharding_config
+
+    def create_sharding_from_config(
+        self, gm: GraphModule, sharding_config: Dict[str, Any] = None
+    ) -> None:
+        """Create sharding transformations from the predefined config.
+
+        TODO: currently, this applies only to TP sharding.
+
+        Args:
+            gm: Graph module to apply transformations to
+            sharding_config: Predefined sharding configuration
+        """
+        if sharding_config is not None:
+            self.predefined_config = sharding_config
+
+        # Check that the config is valid:
+        # 1. it is a Dict[str, str]
+        # 2. the keys are of the format "module.submodule.subsubmodule..."
+        # 3. the wildcard "*" is allowed in the keys
+        # 4. the allowed values are:
+        #    - "colwise"
+        #    - "rowwise"
+        #    - "sequence_parallel"
+        #    - "local_colwise"
+        #    - "local_rowwise"
+        #    - "local"
+        #    - "gather"
+        # These constraints are based on
+        # https://github.com/huggingface/transformers/blob/d8e05951b8efd4880acca9a3f291e8b65841a86d/src/transformers/models/llama4/configuration_llama4.py#L249
+
+        if not isinstance(self.predefined_config, dict):
+            ad_logger.warning("Sharding config is not a dictionary. Skipping.")
+            return
+
+        if "head_dim" not in self.predefined_config:
+            ad_logger.warning("Sharding config does not contain head_dim. Skipping.")
+            return
+        head_dim = self.predefined_config["head_dim"]
+
+        if "tp_plan" not in self.predefined_config:
+            ad_logger.warning("Sharding config does not contain tp_plan. Skipping.")
+            return
+        tp_plan = self.predefined_config["tp_plan"]
+
+        values = set(tp_plan.values())
+        allowed_values = {
+            "colwise",
+            "rowwise",
+            "sequence_parallel",
+            "local_colwise",
+            "local_rowwise",
+            "local_packed_rowwise",
+            "local",
+            "gather",
+        }
+        if not values.issubset(allowed_values):
+            ad_logger.warning("Sharding config contains invalid values. Skipping.")
+            return
+
+        for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
+            module_name = list(lin_node.meta["nn_module_stack"].keys())[-1]
+            # use regex to find if module_name matches any of the keys in sharding_config
+            for key in tp_plan.keys():
+                pattern_string = "*" + key + "*"
+                # Convert it to a regex: escape dots, replace * with .*
+                # WARNING! A very hacky solution: first substitute * with an unlikely
+                # character (e.g. @), then escape the dots, and finally replace @ with .*
+                pattern_string = pattern_string.replace("*", "@")
+                pattern_regex = re.escape(pattern_string).replace("@", ".*")
+                if re.match(pattern_regex, module_name):
+                    # we have a match; get the config for this layer
+                    config = tp_plan[key]
+                    # TODO: @lucaslie: this is SUPER CONFUSING!
+                    # The HF config uses "column" and "row" as if Y = X @ W, so there is an
+                    # all-gather after column and an all-reduce after row.
+                    # But since we assume Y = W @ X^T, the column and row splits are swapped.
+                    if config == "colwise":
+                        # For a colwise split we need to check if we are in an attention
+                        # module. If so, set min_local_shape to head_dim - otherwise we
+                        # would risk splitting the heads into smaller shards.
+                        # TODO: is there a better way to check if we are in an attention module?
+                        attn_names = ["attention", "Attention", "attn", "Attn"]
+                        if any(attn_name in module_name for attn_name in attn_names):
+                            min_local_shape = head_dim
+                        else:
+                            min_local_shape = 1
+                        self.tp_transforms.append(
+                            TPShardingInfo(
+                                target_node=lin_node.name,
+                                split_dim=SplitDimension.ROW,
+                                rank=self.rank,
+                                world_size=self.world_size,
+                                dist_op=None,
+                                min_local_shape=min_local_shape,
+                            )
+                        )
+                    elif config == "rowwise":
+                        self.tp_transforms.append(
+                            TPShardingInfo(
+                                target_node=lin_node.name,
+                                split_dim=SplitDimension.COLUMN,
+                                rank=self.rank,
+                                world_size=self.world_size,
+                                dist_op="all_reduce",
+                                min_local_shape=1,
+                            )
+                        )
+                    elif "sequence" in config:
+                        # TODO: Sequence parallelism is not supported yet.
+                        ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
+                    elif "local" in config:
+                        # TODO: "local" refers to hybrid EP+TP parallelism. Not supported yet.
+                        ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+                    elif "gather" in config:
+                        # Simple shard (row split + all_gather)
+                        self.tp_transforms.append(
+                            TPShardingInfo(
+                                target_node=lin_node.name,
+                                split_dim=SplitDimension.ROW,
+                                rank=self.rank,
+                                world_size=self.world_size,
+                                dist_op="all_gather",
+                                min_local_shape=1,
+                            )
+                        )
+                    else:
+                        ad_logger.warning("Invalid sharding config. Skipping.")
+                    # after a successful match, break the loop
+                    break
+
+    def simple_shard_first_n_layers(self, n_layers: int) -> None:
+        """Simple-shard the first n_layers layers.
+
+        1. Take the existing config self.predefined_config,
+        2. Search for entries whose keys contain the wildcard "*",
+        3. Prepend the same entries with "0, 1, ..., n_layers-1" in place of "*"
+           so they take precedence over the wildcard entry.
+        """
+        new_tp_plan = {}
+        for layer_pattern, config in self.predefined_config["tp_plan"].items():
+            if "*" in layer_pattern:
+                # Create the new dict with the first n_layers entries first
+                for i in range(n_layers):
+                    new_tp_plan[layer_pattern.replace("*", str(i))] = "gather"
+
+            # Add the default config after
+            new_tp_plan[layer_pattern] = config
+
+        self.predefined_config["tp_plan"] = new_tp_plan
+
+    def simple_shard_last_n_layers(self, n_layers: int) -> None:
+        """Simple-shard the last n_layers layers.
+
+        1. Take the existing config self.predefined_config,
+        2. Search for entries whose keys contain the wildcard "*",
+        3. Prepend the same entries with the last n_layers layer indices in place of "*"
+           so they take precedence over the wildcard entry.
+        """
+        new_tp_plan = {}
+        num_layers = self.predefined_config["num_hidden_layers"]
+        for layer_pattern, config in self.predefined_config["tp_plan"].items():
+            if "*" in layer_pattern:
+                # Create the new dict with the last n_layers entries first
+                for i in range(num_layers - n_layers, num_layers):
+                    new_tp_plan[layer_pattern.replace("*", str(i))] = "gather"
+
+            # Add the default config after
+            new_tp_plan[layer_pattern] = config
+        self.predefined_config["tp_plan"] = new_tp_plan
+
+    def simple_shard_attention_layers(self) -> None:
+        """If any key in tp_plan contains "attention", replace its value with "gather"."""
+        for layer_pattern, config in self.predefined_config["tp_plan"].items():
+            if any(
+                attn_name in layer_pattern
+                for attn_name in ["attention", "Attention", "attn", "Attn"]
+            ):
+                self.predefined_config["tp_plan"][layer_pattern] = "gather"
 
 
 def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None:
     """Apply transformations to the graph module.

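A standalone sketch of the wildcard matching used in create_sharding_from_config above: the tp_plan key is wrapped in "*", dots are escaped, and each "*" becomes ".*". The module names in the example are hypothetical nn_module_stack entries, not taken from a real model.

    import re


    def matches(tp_plan_key: str, module_name: str) -> bool:
        # Same trick as in create_sharding_from_config: swap "*" for a placeholder,
        # escape the dots, then turn the placeholder back into ".*".
        pattern_string = ("*" + tp_plan_key + "*").replace("*", "@")
        pattern_regex = re.escape(pattern_string).replace("@", ".*")
        return re.match(pattern_regex, module_name) is not None


    # Hypothetical module names as they might appear in nn_module_stack:
    print(matches("layers.*.mlp.down_proj", "model.layers.3.mlp.down_proj"))  # True
    print(matches("layers.*.mlp.down_proj", "model.layers.3.mlp.up_proj"))    # False
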
tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 28 additions & 13 deletions

@@ -114,19 +114,34 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
         optimize_rope(egm)
 
-        # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
-        sharding_config = ShardingConfig()
-
-        # run TP sharding across ranks
-        detect_column_row_shard(
-            egm, local_rank, world_size, sharding_config, self.ad_config.simple_shard_only
-        )
-
-        # run EP sharding across ranks
-        detect_ep_shard(egm, local_rank, world_size, sharding_config)
-
-        # run BMM sharding across ranks
-        detect_dp_bmm_shard(egm, local_rank, world_size, sharding_config)
+        sharding_config = ShardingConfig(local_rank, world_size, self.factory.get_sharding_config())
+        self.ad_config.use_sharding_from_config = False
+        if (
+            self.ad_config.use_sharding_from_config
+            and sharding_config.predefined_config is not None
+        ):
+            ad_logger.info("\n\nUsing TP sharding from config\n")
+            # sharding_config.simple_shard_attention_layers()
+            sharding_config.create_sharding_from_config(egm)
+        else:
+            ad_logger.info("\n\nRunning TP sharding detection\n")
+            # run TP sharding across ranks
+            detect_column_row_shard(
+                egm, local_rank, world_size, sharding_config, self.ad_config.simple_shard_only
+            )
+
+            # run EP sharding across ranks
+            detect_ep_shard(egm, local_rank, world_size, sharding_config)
+
+            # run BMM sharding across ranks
+            detect_dp_bmm_shard(egm, local_rank, world_size, sharding_config)
+
+        # print detected transformations
+        ad_logger.info("\n\nTP sharding:")
+        for tp_transform in sharding_config.tp_transforms:
+            ad_logger.info(
+                f"{tp_transform.target_node} {tp_transform.split_dim} {tp_transform.dist_op}"
+            )
 
         sharding_transform_executor(egm, sharding_config)

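When the config-driven branch above runs, create_sharding_from_config stops at the first tp_plan key that matches a module, so insertion order in the plan decides which rule a layer gets; the simple_shard_first_n_layers helper relies on this. A minimal, self-contained illustration with a made-up plan:

    # Mimic simple_shard_first_n_layers(n_layers=2): concrete-layer "gather" entries are
    # inserted before the wildcard entry, so they win the first-match lookup for layers 0-1.
    tp_plan = {"layers.*.mlp.down_proj": "rowwise"}

    new_tp_plan = {}
    for pattern, cfg in tp_plan.items():
        if "*" in pattern:
            for i in range(2):
                new_tp_plan[pattern.replace("*", str(i))] = "gather"
        new_tp_plan[pattern] = cfg

    print(new_tp_plan)
    # {'layers.0.mlp.down_proj': 'gather', 'layers.1.mlp.down_proj': 'gather',
    #  'layers.*.mlp.down_proj': 'rowwise'}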