
Commit 79c2c52

deepgemm pre-compile tool support mixed parallel (#4282)
1 parent 5c6e859 commit 79c2c52

File tree

1 file changed (+26, -55 lines changed)

tools/deep_gemm_pre-compile/generate_config.py

Lines changed: 26 additions & 55 deletions
@@ -41,70 +41,41 @@ def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
     gemm_kn_pairs = []
     grouped_gemm_contiguous_kn_pairs = []
     grouped_gemm_masked_kn_pairs = []
-    if tp_size > 1 and ep_size == 1:
-        logger.debug("Generating kn pairs for tensor parallel.")
-        # Dense normal gemm
-        gemm_kn_pairs.extend(
-            [
-                [int(intermediate_size / tp_size), hidden_size],
-                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
-                [hidden_size, int(intermediate_size * 2 / tp_size)],
-                [int(hidden_size / tp_size), hidden_size],
-            ]
-        )
+    logger.debug("Generating kn pairs for tensor parallel.")
+    # Dense normal gemm
+    gemm_kn_pairs.extend(
+        [
+            [int(intermediate_size / tp_size), hidden_size],
+            [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
+            [hidden_size, int(intermediate_size * 2 / tp_size)],
+            [int(hidden_size / tp_size), hidden_size],
+        ]
+    )
 
-        # Moe grouped gemm contiguous
-        grouped_gemm_contiguous_kn_pairs.extend(
-            [
-                [int(moe_intermediate_size / tp_size), hidden_size],
-                [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
-            ]
-        )
-        if has_shared_experts:
-            logger.debug("Generating kn pairs for models with shared experts.")
-            gemm_kn_pairs.extend(
-                [
-                    [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
-                    [int(moe_intermediate_size * 2 / tp_size), hidden_size],
-                ]
-            )
-    elif tp_size == 1 and ep_size > 1:
-        logger.debug("Generating kn pairs for expert parallel.")
-        # Dense normal gemm
-        gemm_kn_pairs.extend(
-            [
-                [intermediate_size, hidden_size],
-                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
-                [hidden_size, int(intermediate_size * 2)],
-                [hidden_size, hidden_size],
-            ]
-        )
-        # Moe grouped gemm contiguous
-        grouped_gemm_contiguous_kn_pairs.extend(
+    # Moe grouped gemm contiguous
+    grouped_gemm_contiguous_kn_pairs.extend(
+        [
+            [int(moe_intermediate_size / tp_size), hidden_size],
+            [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
+        ]
+    )
+
+    if ep_size > 1:
+        # Moe grouped gemm masked
+        grouped_gemm_masked_kn_pairs.extend(
             [
                 [moe_intermediate_size, hidden_size],
                 [hidden_size, int(moe_intermediate_size * 2)],
             ]
         )
-        # Moe grouped gemm masked
-        grouped_gemm_masked_kn_pairs.extend(
+    if has_shared_experts:
+        logger.debug("Generating kn pairs for models with shared experts.")
+        gemm_kn_pairs.extend(
             [
-                [moe_intermediate_size, hidden_size],
-                [hidden_size, int(moe_intermediate_size * 2)],
+                [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
+                [int(moe_intermediate_size * 2 / tp_size), hidden_size],
             ]
         )
-        if has_shared_experts:
-            logger.debug("Generating kn pairs for models with shared experts.")
-            gemm_kn_pairs.extend(
-                [
-                    [hidden_size, int(moe_intermediate_size * 4)],
-                    [int(moe_intermediate_size * 2), hidden_size],
-                ]
-            )
-    elif tp_size > 1 and ep_size > 1:
-        raise ValueError("Not supported to enable EP and TP at the same time for now.")
-    else:
-        raise ValueError("Please check the tensor parallel size and expert parallel size.")
 
     return (
         gemm_kn_pairs,
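
In effect, the branching on (tp_size, ep_size) is gone: dense and contiguous grouped-GEMM shapes are always divided by tp_size (a no-op when tp_size is 1), masked grouped-GEMM shapes are added whenever ep_size > 1, and the previous ValueError for enabling TP and EP together is removed, so mixed parallel configurations can now be pre-compiled. The sketch below is a minimal, self-contained rendering of that logic for illustration only; the standalone function signature, the config keys, and the example values are assumptions, not the tool's actual interface.

# Minimal sketch of the KN-pair generation after this commit.
# Assumption: the config keys and the standalone signature are illustrative,
# not necessarily what tools/deep_gemm_pre-compile/generate_config.py reads.
from typing import Dict, List, Tuple


def generate_kn_pairs_sketch(
    model_cfg: Dict[str, int], tp_size: int, ep_size: int, has_shared_experts: bool
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    hidden_size = model_cfg["hidden_size"]
    intermediate_size = model_cfg["intermediate_size"]
    moe_intermediate_size = model_cfg["moe_intermediate_size"]
    head_dim = model_cfg["head_dim"]
    num_attention_heads = model_cfg["num_attention_heads"]
    num_key_value_heads = model_cfg["num_key_value_heads"]

    # Dense normal GEMMs: always sharded by tp_size (tp_size == 1 is a no-op),
    # which is what lets TP and EP coexist after this change.
    gemm_kn_pairs = [
        [int(intermediate_size / tp_size), hidden_size],
        [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
        [hidden_size, int(intermediate_size * 2 / tp_size)],
        [int(hidden_size / tp_size), hidden_size],
    ]

    # MoE grouped GEMM, contiguous layout: also sharded by tp_size.
    grouped_gemm_contiguous_kn_pairs = [
        [int(moe_intermediate_size / tp_size), hidden_size],
        [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
    ]

    # MoE grouped GEMM, masked layout: only needed when expert parallel is on.
    grouped_gemm_masked_kn_pairs = []
    if ep_size > 1:
        grouped_gemm_masked_kn_pairs.extend(
            [
                [moe_intermediate_size, hidden_size],
                [hidden_size, int(moe_intermediate_size * 2)],
            ]
        )

    # Shared experts add extra dense GEMMs, again sharded by tp_size.
    if has_shared_experts:
        gemm_kn_pairs.extend(
            [
                [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
                [int(moe_intermediate_size * 2 / tp_size), hidden_size],
            ]
        )

    return gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs


if __name__ == "__main__":
    # Hypothetical config values, chosen only to illustrate mixed TP + EP output.
    cfg = {
        "hidden_size": 4096,
        "intermediate_size": 12288,
        "moe_intermediate_size": 1536,
        "head_dim": 128,
        "num_attention_heads": 32,
        "num_key_value_heads": 4,
    }
    dense, contiguous, masked = generate_kn_pairs_sketch(cfg, tp_size=2, ep_size=8, has_shared_experts=True)
    print("dense:", dense)
    print("contiguous:", contiguous)
    print("masked:", masked)

Running the sketch with tp_size=2 and ep_size=8 emits both the tp_size-sharded dense/contiguous shapes and the unsharded masked MoE shapes together, which is exactly the combination the old code rejected with a ValueError.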
