-
Notifications
You must be signed in to change notification settings - Fork 307
Expand file tree
/
Copy pathcustom_parallel.py
More file actions
62 lines (55 loc) · 2.57 KB
/
custom_parallel.py
File metadata and controls
62 lines (55 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.placement_types import Replicate, Shard
custom_parallel_plan = {
"model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
"model.layers.*.self_attn.o_proj": RowwiseParallel(),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(),
"lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}
"""
Note on numerical stability:
- Default plans that keep the attention output projection (o_proj) and the MLP down
projection (down_proj) as RowwiseParallel are numerically unstable, and the instability
tends to grow with larger TP sizes (e.g., TP >= 4).
Enable this custom plan via:
- policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.qwen_model_tp_plan_stable
Based on https://github.com/NVIDIA-NeMo/Automodel/blob/d79ccb94b0eca94a4c479313db2f9eee80db0139/nemo_automodel/components/distributed/optimized_tp_plans.py#L205-L217
"""
qwen_model_tp_plan_stable = {
"lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
"model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(),
),
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
"model.layers.*.self_attn.o_proj": ColwiseParallel(
input_layouts=Shard(-1),
output_layouts=Replicate(),
use_local_output=True,
),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": ColwiseParallel(
input_layouts=Shard(-1),
output_layouts=Replicate(),
use_local_output=True,
),
}