
Commit 9afc419

feat: more numerically stable qwen custom plan (NVIDIA-NeMo#1235)
Signed-off-by: Terry Kong <terryk@nvidia.com>
Parent: dccdf79

File tree

1 file changed: +34 −0 lines


examples/custom_parallel.py

Lines changed: 34 additions & 0 deletions
@@ -26,3 +26,37 @@
     "model.layers.*.mlp.down_proj": RowwiseParallel(),
     "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
 }
+
+"""
+Note on numerical stability:
+
+- Default plans that keep the attention output projection and MLP down projection as
+  RowwiseParallel are numerically unstable; the error tends to grow with larger TP (e.g., TP >= 4).
+
+Enable this custom plan via:
+
+- policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.qwen_model_tp_plan_stable
+
+Based on https://github.com/NVIDIA-NeMo/Automodel/blob/d79ccb94b0eca94a4c479313db2f9eee80db0139/nemo_automodel/components/distributed/optimized_tp_plans.py#L205-L217
+"""
+qwen_model_tp_plan_stable = {
+    "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
+    "model.embed_tokens": RowwiseParallel(
+        input_layouts=Replicate(),
+    ),
+    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
+    "model.layers.*.self_attn.k_proj": ColwiseParallel(),
+    "model.layers.*.self_attn.v_proj": ColwiseParallel(),
+    "model.layers.*.self_attn.o_proj": ColwiseParallel(
+        input_layouts=Shard(-1),
+        output_layouts=Replicate(),
+        use_local_output=True,
+    ),
+    "model.layers.*.mlp.up_proj": ColwiseParallel(),
+    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
+    "model.layers.*.mlp.down_proj": ColwiseParallel(
+        input_layouts=Shard(-1),
+        output_layouts=Replicate(),
+        use_local_output=True,
+    ),
+}
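The wildcard keys in the plan (e.g. `model.layers.*.self_attn.o_proj`) are glob-style patterns matched against module fully-qualified names. A minimal sketch of that matching semantics using Python's `fnmatch` — an illustration only, not the actual DTensor implementation, and `find_plan_entry` is a hypothetical helper:

```python
from fnmatch import fnmatch

def find_plan_entry(plan_keys, fqn):
    # Hypothetical helper: return the first plan pattern that
    # matches a module's fully-qualified name, or None.
    for pattern in plan_keys:
        if fnmatch(fqn, pattern):
            return pattern
    return None

plan_keys = [
    "lm_head",
    "model.embed_tokens",
    "model.layers.*.self_attn.o_proj",
    "model.layers.*.mlp.down_proj",
]

# The layer index is absorbed by the `*` wildcard.
print(find_plan_entry(plan_keys, "model.layers.7.self_attn.o_proj"))
# -> model.layers.*.self_attn.o_proj
```

Modules whose FQN matches no pattern (e.g. `model.norm`) are left unparallelized by such a plan.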
