Commit cff54fc

[NVIDIA#8948][feat] Support custom sharding config (NVIDIA#9143)
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent bc355ea commit cff54fc

File tree: 9 files changed, +249 -160 lines

docs/source/torch/auto_deploy/advanced/expert_configurations.md

Lines changed: 79 additions & 0 deletions
@@ -153,6 +153,85 @@ python build_and_run_ad.py \
  --args.world-size=8 # CLI override beats both YAML configs

## Sharding configuration

The `detect_sharding` transform automatically detects and applies sharding strategies to the model. It supports multiple sharding sources and dimensions, allowing flexible configuration for different model architectures and parallelism strategies.

### Configuration Parameters

The `detect_sharding` transform accepts the following configuration parameters:

#### `simple_shard_only` (bool, default: `false`)

When set to `true`, forces simple sharding (row-wise sharding with all-gather) for all linear layers, bypassing more sophisticated column/row sharding strategies. This is useful when you want a uniform sharding approach across all layers or when debugging sharding issues.
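
For example, a minimal override that forces simple sharding everywhere (a sketch using the same `args.transforms.detect_sharding` nesting as the manual example further below) could look like this:

```yaml
args:
  transforms:
    detect_sharding:
      # Row-wise shard + all-gather every linear layer, regardless of
      # what the heuristics would otherwise choose.
      simple_shard_only: true
```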

#### `sharding_source` (list, default: `['manual', 'factory', 'heuristic']`)

Specifies the priority order of sharding sources. The order matters: if multiple sources try to apply sharding to the same layer, only the first one in the list will be applied. The available sources are:

- **`'manual'`**: Uses manually provided sharding configuration via the `manual_config` parameter
- **`'factory'`**: Uses factory-provided sharding configuration (e.g., from HuggingFace model configs)
- **`'heuristic'`**: Uses automatic heuristic-based sharding detection based on layer patterns

Example: If both `manual` and `heuristic` try to apply sharding to layer L, only the `manual` transformation will be applied since it appears first in the list.
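
As an illustration, the following sketch (an assumed but plausible combination given the semantics above) prefers the manual plan and falls back to heuristics, skipping factory-provided plans entirely:

```yaml
args:
  transforms:
    detect_sharding:
      # 'manual' wins for any layer it covers; layers it does not cover
      # fall back to heuristic detection. 'factory' is not consulted.
      sharding_source: ['manual', 'heuristic']
```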

#### `support_partial_config` (bool, default: `true`)

When `true`, allows partial sharding configurations in which not every layer needs to be specified in the manual or factory config. Layers that are not explicitly configured are handled by heuristic sharding or left unsharded. When `false`, the configuration must specify all layers; otherwise it is invalidated and skipped.
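
For example, with partial configs enabled you could manually shard only the attention projections and leave everything else to the heuristics. This is a sketch that reuses the layer names from the full example later in this section; the `head_dim` value is purely illustrative:

```yaml
args:
  transforms:
    detect_sharding:
      support_partial_config: true
      manual_config:
        head_dim: 128  # illustrative value; set to match your model
        tp_plan:
          # Only the attention projections are listed; all remaining
          # layers are handled by heuristic sharding or left unsharded.
          q_proj: colwise
          k_proj: colwise
          v_proj: colwise
          o_proj: rowwise
```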

#### `sharding_dims` (list, default: `['tp', 'ep', 'bmm']`)

Specifies which sharding dimensions to apply during heuristic sharding. The available dimensions are:

- **`'tp'`**: Tensor parallelism - applies column/row sharding for standard transformer layers
- **`'ep'`**: Expert parallelism - shards experts across ranks for Mixture-of-Experts (MoE) models
- **`'bmm'`**: Batch matrix multiplication sharding - shards batch matrix multiplication operations
- **`'ssm'`**: State space model sharding - applies specialized sharding for Mamba/SSM layers

You can enable multiple dimensions simultaneously. For example, `['tp', 'ep']` will apply both tensor parallelism and expert parallelism.
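
As a concrete sketch, restricting heuristic sharding to tensor and expert parallelism would look roughly like this:

```yaml
args:
  transforms:
    detect_sharding:
      # Apply TP and EP heuristics; skip BMM and SSM sharding.
      sharding_dims: ['tp', 'ep']
```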

#### `requires_shape_prop` (bool, default: `true`)

Whether shape propagation is required before applying this transform. Shape propagation enables the transform to make informed decisions about sharding strategies based on tensor dimensions.

### Manual TP Sharding Configuration

For advanced users, you can provide a manual sharding configuration. An example of such a setting:

```yaml
args:
  transforms:
    detect_sharding:
      manual_config:
        head_dim: 128
        tp_plan:
          # mamba SSM layers
          in_proj: mamba
          out_proj: rowwise
          # attention layers
          q_proj: colwise
          k_proj: colwise
          v_proj: colwise
          o_proj: rowwise
          # NOTE: for performance reasons, consider not sharding the
          # following layers at all. Commenting them out will replicate
          # them across ranks.
          # MLP and shared experts in MoE layers
          gate_proj: colwise
          up_proj: colwise
          down_proj: rowwise
          # MoLE: latent projections: simple shard
          fc1_latent_proj: gather
          fc2_latent_proj: gather
```

The `tp_plan` dictionary maps layer names (using module paths with wildcard `*` support) to sharding strategies:

- **`colwise`**: Column-wise sharding (splits the weight matrix along columns)
- **`rowwise`**: Row-wise sharding (splits the weight matrix along rows)
- **`mamba`**: Specialized sharding for Mamba SSM layers
- **`gather`**: Simple shard with row-wise sharding and an all-gather operation
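
Because layer names support wildcard `*` matching on module paths, a plan can also target layers by pattern. The module paths below are hypothetical and only illustrate the matching syntax:

```yaml
args:
  transforms:
    detect_sharding:
      manual_config:
        tp_plan:
          # Hypothetical module paths: shard every attention projection
          # in every decoder layer.
          "model.layers.*.self_attn.q_proj": colwise
          "model.layers.*.self_attn.k_proj": colwise
          "model.layers.*.self_attn.v_proj": colwise
          "model.layers.*.self_attn.o_proj": rowwise
          # Simple-shard (row-wise + all-gather) everything in the MLP blocks.
          "model.layers.*.mlp.*": gather
```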

## Built-in Default Configuration

Both `AutoDeployConfig` and `LlmArgs` classes automatically load a built-in `default.yaml` configuration file that provides defaults for the AutoDeploy inference optimizer pipeline. This file is specified in the `_get_config_dict()` function in `tensorrt_llm._torch.auto_deploy.llm_args` and defines default transform configurations for graph optimization stages.

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ transforms:
   detect_sharding:
     stage: sharding
     simple_shard_only: false
-    sharding_source: ['factory','heuristic']
+    sharding_source: ['manual', 'factory', 'heuristic']
     support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
     allreduce_strategy: 'AUTO'

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 4 additions & 2 deletions
@@ -24,7 +24,7 @@
     run_shape_prop,
 )
 from ..utils.logger import ad_logger
-from ..utils.sharding_utils import ShardingConfig
+from ..utils.sharding_utils import ShardingTransformContainer


 class TransformError(Exception):
@@ -61,7 +61,9 @@ def __lt__(self, other):
 class SharedConfig(BaseModel):
     """Global config shared between multiple transforms in the inference optimizer."""

-    sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
+    sharding_transform_container: ShardingTransformContainer = Field(
+        default_factory=ShardingTransformContainer
+    )
     local_rank: int = Field(default=0)
     world_size: int = Field(default=1)
