@@ -41,70 +41,41 @@ def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
     gemm_kn_pairs = []
     grouped_gemm_contiguous_kn_pairs = []
     grouped_gemm_masked_kn_pairs = []
-    if tp_size > 1 and ep_size == 1:
-        logger.debug("Generating kn pairs for tensor parallel.")
-        # Dense normal gemm
-        gemm_kn_pairs.extend(
-            [
-                [int(intermediate_size / tp_size), hidden_size],
-                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
-                [hidden_size, int(intermediate_size * 2 / tp_size)],
-                [int(hidden_size / tp_size), hidden_size],
-            ]
-        )
+    logger.debug("Generating kn pairs for tensor parallel.")
+    # Dense normal gemm
+    gemm_kn_pairs.extend(
+        [
+            [int(intermediate_size / tp_size), hidden_size],
+            [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
+            [hidden_size, int(intermediate_size * 2 / tp_size)],
+            [int(hidden_size / tp_size), hidden_size],
+        ]
+    )
 
-        # Moe grouped gemm contiguous
-        grouped_gemm_contiguous_kn_pairs.extend(
-            [
-                [int(moe_intermediate_size / tp_size), hidden_size],
-                [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
-            ]
-        )
-        if has_shared_experts:
-            logger.debug("Generating kn pairs for models with shared experts.")
-            gemm_kn_pairs.extend(
-                [
-                    [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
-                    [int(moe_intermediate_size * 2 / tp_size), hidden_size],
-                ]
-            )
-    elif tp_size == 1 and ep_size > 1:
-        logger.debug("Generating kn pairs for expert parallel.")
-        # Dense normal gemm
-        gemm_kn_pairs.extend(
-            [
-                [intermediate_size, hidden_size],
-                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
-                [hidden_size, int(intermediate_size * 2)],
-                [hidden_size, hidden_size],
-            ]
-        )
-        # Moe grouped gemm contiguous
-        grouped_gemm_contiguous_kn_pairs.extend(
+    # Moe grouped gemm contiguous
+    grouped_gemm_contiguous_kn_pairs.extend(
+        [
+            [int(moe_intermediate_size / tp_size), hidden_size],
+            [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
+        ]
+    )
+
+    if ep_size > 1:
+        # Moe grouped gemm masked
+        grouped_gemm_masked_kn_pairs.extend(
             [
                 [moe_intermediate_size, hidden_size],
                 [hidden_size, int(moe_intermediate_size * 2)],
             ]
         )
-        # Moe grouped gemm masked
-        grouped_gemm_masked_kn_pairs.extend(
+    if has_shared_experts:
+        logger.debug("Generating kn pairs for models with shared experts.")
+        gemm_kn_pairs.extend(
             [
-                [moe_intermediate_size, hidden_size],
-                [hidden_size, int(moe_intermediate_size * 2)],
+                [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
+                [int(moe_intermediate_size * 2 / tp_size), hidden_size],
             ]
         )
-        if has_shared_experts:
-            logger.debug("Generating kn pairs for models with shared experts.")
-            gemm_kn_pairs.extend(
-                [
-                    [hidden_size, int(moe_intermediate_size * 4)],
-                    [int(moe_intermediate_size * 2), hidden_size],
-                ]
-            )
-    elif tp_size > 1 and ep_size > 1:
-        raise ValueError("Not supported to enable EP and TP at the same time for now.")
-    else:
-        raise ValueError("Please check the tensor parallel size and expert parallel size.")
 
     return (
         gemm_kn_pairs,
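Summary of the change: the branch on parallelism mode is gone. Dense and contiguous grouped GEMM shapes are now always divided by tp_size (a no-op when tp_size == 1), the masked grouped GEMM shapes are emitted only when ep_size > 1, and the old ValueError for combining TP and EP is removed, so both can be enabled together. The snippet below is a minimal, self-contained sketch of that resulting behavior; the helper name sketch_kn_pairs and the example config values are illustrative assumptions, not code or numbers from this repository.

```python
# Minimal sketch (not the project code): mirrors the patched kn-pair logic so the
# effect of the diff can be checked standalone. All config values are hypothetical.

def sketch_kn_pairs(hidden_size, intermediate_size, moe_intermediate_size,
                    head_dim, num_attention_heads, num_key_value_heads,
                    tp_size, ep_size, has_shared_experts):
    # Dense normal gemm: always sharded by tp_size (dividing by 1 changes nothing)
    gemm = [
        [int(intermediate_size / tp_size), hidden_size],
        [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
        [hidden_size, int(intermediate_size * 2 / tp_size)],
        [int(hidden_size / tp_size), hidden_size],
    ]
    # Moe grouped gemm contiguous: also sharded by tp_size
    grouped_contiguous = [
        [int(moe_intermediate_size / tp_size), hidden_size],
        [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
    ]
    # Moe grouped gemm masked: only generated when expert parallelism is enabled
    grouped_masked = []
    if ep_size > 1:
        grouped_masked += [
            [moe_intermediate_size, hidden_size],
            [hidden_size, int(moe_intermediate_size * 2)],
        ]
    # Shared experts: their gemm shapes are sharded by tp_size as well
    if has_shared_experts:
        gemm += [
            [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
            [int(moe_intermediate_size * 2 / tp_size), hidden_size],
        ]
    return gemm, grouped_contiguous, grouped_masked


if __name__ == "__main__":
    # Hypothetical config with TP=2 and EP=8 together, which the old code rejected.
    gemm, contiguous, masked = sketch_kn_pairs(
        hidden_size=7168, intermediate_size=18432, moe_intermediate_size=2048,
        head_dim=128, num_attention_heads=128, num_key_value_heads=128,
        tp_size=2, ep_size=8, has_shared_experts=True,
    )
    print("dense gemm:", gemm)
    print("grouped contiguous:", contiguous)
    print("grouped masked:", masked)
```

Note the masked grouped GEMM shapes are intentionally not divided by tp_size, matching both the old EP-only branch and the new unified path.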