  --args.world-size=8 # CLI override beats both YAML configs
```

## Sharding Configuration

The `detect_sharding` transform automatically detects and applies sharding strategies to the model. It supports multiple sharding sources and dimensions, allowing flexible configuration for different model architectures and parallelism strategies.

### Configuration Parameters

The `detect_sharding` transform accepts the following configuration parameters:

#### `simple_shard_only` (bool, default: `false`)

When set to `true`, forces simple sharding (row-wise sharding with all-gather) for all linear layers, bypassing the more sophisticated column/row sharding strategies. This is useful when you want a uniform sharding approach across all layers or when debugging sharding issues.
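For instance, to force simple sharding everywhere, flip this flag in the transform's configuration. A minimal sketch that reuses the same `args.transforms` YAML layout as the manual example later in this section (the nesting level of the flag is assumed to match that example):

```yaml
args:
  transforms:
    detect_sharding:
      simple_shard_only: true  # every linear layer: row-wise shard + all-gather
```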

Specifies the priority order of sharding sources. The order matters: if multiple sources try to apply sharding to the same layer, only the first one in the list is applied. The available sources are:

- **`'manual'`**: Uses the manually provided sharding configuration passed via the `manual_config` parameter
- **`'factory'`**: Uses a factory-provided sharding configuration (e.g., from HuggingFace model configs)
- **`'heuristic'`**: Uses automatic heuristic-based sharding detection based on layer patterns

Example: If both `manual` and `heuristic` try to apply sharding to layer L, only the `manual` transformation is applied, since it appears first in the list.
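The first-wins rule can be illustrated with a small standalone sketch (plain Python; this is an illustration of the rule described above, not the actual TensorRT-LLM implementation, and the layer names are made up):

```python
# Illustrative sketch of the "first source wins" priority rule:
# a later source never overrides an earlier one for the same layer.
def resolve_sharding(sources, proposals):
    """sources: priority-ordered list, e.g. ['manual', 'factory', 'heuristic'].
    proposals: {source: {layer_name: strategy}}.
    Returns {layer_name: (source, strategy)} where the earliest source wins."""
    resolved = {}
    for source in sources:
        for layer, strategy in proposals.get(source, {}).items():
            # setdefault keeps the first (highest-priority) entry for each layer.
            resolved.setdefault(layer, (source, strategy))
    return resolved

plan = resolve_sharding(
    ["manual", "factory", "heuristic"],
    {
        "manual": {"model.layers.0.q_proj": "colwise"},
        "heuristic": {
            "model.layers.0.q_proj": "rowwise",   # ignored: manual came first
            "model.layers.0.o_proj": "rowwise",   # kept: no earlier source set it
        },
    },
)
```

Here `manual` wins for `q_proj`, while `heuristic` still fills in `o_proj`, which no earlier source configured.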
When `true`, allows partial sharding configurations where not all layers need to be specified in the manual or factory config. Layers not explicitly configured will be handled by heuristic sharding or left unsharded. When `false`, the configuration must specify all layers or it will be invalidated and skipped.

Whether shape propagation is required before applying this transform. Shape propagation enables the transform to make informed decisions about sharding strategies based on tensor dimensions.

### Manual TP Sharding Configuration

For advanced users, you can provide a manual sharding configuration. An example of such a setting:

```yaml
args:
  transforms:
    detect_sharding:
      manual_config:
        head_dim: 128
        tp_plan:
          # mamba SSM layers
          in_proj: mamba
          out_proj: rowwise
          # attention layers
          q_proj: colwise
          k_proj: colwise
          v_proj: colwise
          o_proj: rowwise
          # NOTE: for performance reasons, consider not sharding the following
          # layers at all. Commenting out the following layers will replicate
          # them across ranks.
          # MLP and shared experts in MoE layers
          gate_proj: colwise
          up_proj: colwise
          down_proj: rowwise
          # MoLE: latent projections: simple shard
          fc1_latent_proj: gather
          fc2_latent_proj: gather
```

The `tp_plan` dictionary maps layer names (using module paths with wildcard `*` support) to sharding strategies:

- **`colwise`**: Column-wise sharding (splits the weight matrix along columns)
- **`rowwise`**: Row-wise sharding (splits the weight matrix along rows)
- **`mamba`**: Specialized sharding for Mamba SSM layers
- **`gather`**: Simple shard with row-wise sharding and an all-gather operation

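Numerically, `colwise` and `rowwise` recover the full output of a linear layer `y = x @ W` in different ways: concatenating per-rank output slices (all-gather) versus summing per-rank partial products (all-reduce). A minimal NumPy sketch with two simulated ranks (illustration only; no tensor-parallel runtime is involved):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))   # (batch, in_features)
W = rng.standard_normal((4, 6))   # (in_features, out_features)
y_ref = x @ W                     # unsharded reference output

# colwise: each rank holds a column slice of W and produces a slice of y;
# the full output is recovered by concatenation (all-gather across ranks).
y_col = np.concatenate([x @ W[:, :3], x @ W[:, 3:]], axis=1)

# rowwise: each rank holds a row slice of W and the matching slice of x;
# the partial outputs are summed (all-reduce across ranks).
y_row = (x[:, :2] @ W[:2, :]) + (x[:, 2:] @ W[2:, :])

assert np.allclose(y_col, y_ref) and np.allclose(y_row, y_ref)
```

This is why `colwise` layers (e.g. `q_proj`) can feed directly into a `rowwise` layer (e.g. `o_proj`) with only a single collective at the end of the pair.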
## Built-in Default Configuration
Both `AutoDeployConfig` and `LlmArgs` classes automatically load a built-in `default.yaml` configuration file that provides defaults for the AutoDeploy inference optimizer pipeline. This file is specified in the `_get_config_dict()` function in `tensorrt_llm._torch.auto_deploy.llm_args` and defines default transform configurations for graph optimization stages.