
Commit 1ae6ed8

phantomlei3 authored and yq33victor committed
feat: support w8a8 moe on mlu for deepseek v3.2 prerequisite.
1 parent 8376517 commit 1ae6ed8

13 files changed: +919, -118 lines

xllm/core/kernels/mlu/fused_moe.cpp

Lines changed: 152 additions & 69 deletions
@@ -16,13 +16,17 @@ limitations under the License.
 #include "mlu_ops_api.h"
 
 namespace {
-torch::Tensor create_group_gemm_output(const torch::Tensor& a,
-                                       const torch::Tensor& b,
-                                       const torch::Tensor& group_list) {
+torch::Tensor create_group_gemm_output(
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& group_list,
+    c10::ScalarType dtype = c10::ScalarType::BFloat16) {
+  torch::TensorOptions target_options = a.options().dtype(dtype);
   if (b.dim() != 2) {
-    return torch::empty({a.size(0), b.size(1)}, a.options());
+    return torch::empty({a.size(0), b.size(1)}, target_options);
   }
-  return torch::empty({group_list.size(0), a.size(0), b.size(0)}, a.options());
+  return torch::empty({group_list.size(0), a.size(0), b.size(0)},
+                      target_options);
 }
 }  // namespace
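The helper now takes an explicit output dtype (defaulting to bfloat16) because in the w8a8 path the inputs to group_gemm are int8, so deriving the output options from a.options() would allocate an int8 result tensor. A minimal sketch of the difference; the shapes below are hypothetical and only for illustration:

// Sketch (not from the commit): why the explicit dtype matters for w8a8.
torch::Tensor a = torch::empty({128, 1024}, torch::kChar);      // int8 activations
torch::Tensor b = torch::empty({8, 2048, 1024}, torch::kChar);  // per-expert int8 weights
torch::Tensor group_list = torch::empty({8}, torch::kInt);

auto out_bf16 = create_group_gemm_output(a, b, group_list);  // bf16 output by default
auto out_fp32 = create_group_gemm_output(a, b, group_list,
                                         c10::ScalarType::Float);  // explicit output dtype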

@@ -45,13 +49,16 @@ torch::Tensor fused_moe(
     bool gated,
     const std::string& act_mode,
     const std::string& scoring_func,
+    int num_expert_group,
+    int topk_group,
+    double route_scale,
     int start_expert_id,
     int block_n,
     bool avg_moe,
     const std::optional<torch::Tensor>& class_reduce_weight,
     const std::optional<torch::Tensor>& class_expert_id,
-    const std::optional<std::vector<bool>>& w1_quant_flag,
-    const std::optional<std::vector<bool>>& w2_quant_flag,
+    const std::optional<torch::List<int64_t>>& w1_quant_flag,
+    const std::optional<torch::List<int64_t>>& w2_quant_flag,
     int world_size,
     int shared_expert_num,
     const std::string& parallel_mode) {
@@ -67,25 +74,33 @@ torch::Tensor fused_moe(
     residual_2d = residual.value().reshape({-1, residual.value().size(-1)});
   }
 
+  // check smooth quant variables
+  bool all_present = input_smooth && act_smooth && w1_scale && w2_scale;
+  bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale;
+  TORCH_CHECK(all_none || all_present,
+              "input_smooth, act_smooth, w1_scale and w2_scale must be present "
+              "or absent at the same time.");
+  bool is_smoothquant = all_present;
   int64_t expert_num = gating_output_2d.size(-1);
   int64_t expert_size = w1.size(0);
 
-  // softmax_topk
+  // apply softmax_topk or sigmoid_topk
   auto reduce_weight = torch::empty(
       {gating_output_2d.size(0), topk},
       torch::dtype(torch::kFloat).device(gating_output_2d.device()));
   auto expert_id = torch::empty(
       {gating_output_2d.size(0), topk},
       torch::dtype(torch::kInt32).device(gating_output_2d.device()));
+
   tmo::torch_api::moe_active_topk(gating_output_2d,
                                   topk,
-                                  -1,
-                                  0,
+                                  num_expert_group,
+                                  topk_group,
                                   renormalize,
-                                  std::nullopt,
-                                  "topk_logit",
+                                  std::nullopt,  // mask
+                                  "topk_logit",  // normed_by
                                   scoring_func,
-                                  1.0,
+                                  route_scale,
                                   e_score_correction_bias,
                                   reduce_weight,
                                   expert_id);
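moe_active_topk now receives num_expert_group, topk_group and route_scale instead of the hard-coded -1, 0 and 1.0, which is what DeepSeek-style grouped routing needs: experts are partitioned into groups, only the best topk_group groups stay eligible, and the selected routing weights are scaled. The MLU kernel handles this internally; the sketch below is only a rough, device-agnostic reference for that selection scheme (renormalization and e_score_correction_bias handling are omitted), not the tmo::torch_api implementation:

#include <torch/torch.h>
#include <limits>
#include <utility>

// Rough reference of grouped top-k routing (assumption, for illustration only).
// scores: [num_tokens, num_experts], experts laid out contiguously per group.
std::pair<torch::Tensor, torch::Tensor> grouped_topk_reference(
    const torch::Tensor& scores,  // float, post-softmax/sigmoid gating scores
    int64_t topk,
    int64_t num_expert_group,
    int64_t topk_group,
    double route_scale) {
  int64_t num_tokens = scores.size(0);
  int64_t experts_per_group = scores.size(1) / num_expert_group;

  // score each group by its best expert, keep the topk_group best groups
  auto group_scores = std::get<0>(
      scores.view({num_tokens, num_expert_group, experts_per_group}).max(-1));
  auto group_idx = std::get<1>(group_scores.topk(topk_group, /*dim=*/-1));
  auto group_mask = torch::zeros_like(group_scores).scatter_(-1, group_idx, 1.0);

  // mask out experts of the discarded groups, then pick topk experts overall
  auto expert_mask = group_mask.unsqueeze(-1)
                         .expand({num_tokens, num_expert_group, experts_per_group})
                         .reshape({num_tokens, -1});
  auto masked = scores.masked_fill(expert_mask == 0,
                                   -std::numeric_limits<float>::infinity());
  auto [weights, expert_ids] = masked.topk(topk, /*dim=*/-1);
  return {weights * route_scale, expert_ids};
}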
@@ -95,69 +110,137 @@ torch::Tensor fused_moe(
   auto combine_idx = output_vec[1];
   auto token_count = output_vec[2];
   auto cusum_token_count = output_vec[3];
-  torch::Tensor expand_hidden_states =
-      tmo::torch_api::moe_expand_input(hidden_states_2d,
-                                       expand_idx,
-                                       cusum_token_count,
-                                       start_expert_id,
-                                       expert_size);
 
+  // prepare the parameters for the first group gemm
   auto token_count_slice =
       token_count.slice(0, start_expert_id, start_expert_id + expert_size);
-  torch::Tensor gemm1_out =
-      create_group_gemm_output(expand_hidden_states, w1, token_count_slice);
-  tmo::torch_api::group_gemm(expand_hidden_states,
-                             w1,
-                             token_count_slice,
-                             gemm1_out,
-                             std::nullopt,  // expand_idx
-                             std::nullopt,  // c
-                             std::nullopt,  // alpha
-                             std::nullopt,  // beta
-                             std::nullopt,  // a_scale
-                             std::nullopt,  // b_scale
-                             std::nullopt,  // bias
-                             std::nullopt,  // a_calibration
-                             std::nullopt,  // b_calibration
-                             std::nullopt,  // quant_flag
-                             tokens,
-                             false,
-                             true);
+  auto gather_index_start_position =
+      cusum_token_count.index({start_expert_id}).unsqueeze(0);
+  torch::Tensor expand_hidden_states;
+  torch::Tensor input_scale;
+
+  if (is_smoothquant) {
+    // w8a8 path: quantize input hidden states directly (fused with
+    // moe_expand_input)
+    std::tie(expand_hidden_states, input_scale) =
+        xllm::kernel::mlu::scaled_quantize(
+            hidden_states_2d,  // use original hidden_states_2d instead of
+                               // expand_hidden_states
+            input_smooth.value(),
+            std::nullopt,  // zero
+            token_count_slice,
+            expand_idx,
+            gather_index_start_position,
+            std::nullopt,  // output
+            std::nullopt,  // output_scale
+            "none",        // act_mode
+            1.0,           // active_coef
+            false,         // is_gated
+            torch::kChar   // quant_type
+        );
+  } else {
+    // bf16/fp32 path: expand input hidden states
+    expand_hidden_states = tmo::torch_api::moe_expand_input(hidden_states_2d,
+                                                            expand_idx,
+                                                            cusum_token_count,
+                                                            start_expert_id,
+                                                            expert_size);
+  }
+
+  torch::Tensor gemm1_out = create_group_gemm_output(
+      expand_hidden_states, w1, token_count_slice, dtype.toScalarType());
 
+  // Unified group_gemm call using input_scale/w1_scale/quant_flag only if
+  // present
+  tmo::torch_api::group_gemm(
+      expand_hidden_states,  // a
+      w1,                    // b
+      token_count_slice,     // m_list
+      gemm1_out,             // d
+      std::nullopt,          // gather_idx
+      std::nullopt,          // c
+      std::nullopt,          // alpha
+      std::nullopt,          // beta
+      input_scale.defined() ? std::make_optional(input_scale)
+                            : std::nullopt,  // a_scale
+      w1_scale.has_value() ? std::make_optional(w1_scale.value())
+                           : std::nullopt,  // b_scale
+      std::nullopt,                         // bias
+      w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt,  // quant_flag
+      std::nullopt,  // b_offset
+      std::nullopt,  // tile_config
+      tokens,        // max_dim
+      false,         // trans_a
+      true           // trans_b
+  );
+
+  // prepare the parameters for the second group gemm
   torch::Tensor act_out;
-  if (gated) {
-    act_out = gemm1_out.slice(1, 0, gemm1_out.size(1) / 2);
+  torch::Tensor act_out_scale;
+  if (is_smoothquant) {
+    // w8a8 path: reuse quantized_input and input_scale from first group_gemm
+    act_out = gated ? expand_hidden_states.slice(1, 0, gemm1_out.size(1) / 2)
+                    : expand_hidden_states.slice(1, 0, gemm1_out.size(1));
+    act_out_scale = input_scale.slice(0, 0, gemm1_out.size(0));
+
+    // Quantize gemm1_out directly (fused with active operation) using reused
+    // tensors
+    auto [quantized_activation, activation_scale] =
+        xllm::kernel::mlu::scaled_quantize(
+            gemm1_out,
+            act_smooth.value(),
+            std::nullopt,   // zero
+            token_count_slice,
+            std::nullopt,   // gather_index
+            std::nullopt,   // gather_index_start_position
+            act_out,        // output - reuse from quantized_input
+            act_out_scale,  // output_scale - reuse from input_scale
+            act_mode,       // act_mode
+            1.0,            // active_coef
+            gated,          // is_gated
+            torch::kChar    // quant_type
+        );
+    act_out = quantized_activation;
+    act_out_scale = activation_scale;
   } else {
-    act_out = gemm1_out;
+    // bf16/fp32 path: apply activation function first
+    act_out = gated ? gemm1_out.slice(1, 0, gemm1_out.size(1) / 2) : gemm1_out;
+    tmo::torch_api::active(gemm1_out,
+                           act_out,
+                           bias1,
+                           cusum_token_count,
+                           act_mode,
+                           gated,
+                           start_expert_id,
+                           expert_size);
   }
-  tmo::torch_api::active(gemm1_out,
-                         act_out,
-                         bias1,
-                         cusum_token_count,
-                         act_mode,
-                         gated,
-                         start_expert_id,
-                         expert_size);
-
-  torch::Tensor gemm2_out =
-      create_group_gemm_output(act_out, w2, token_count_slice);
-  tmo::torch_api::group_gemm(act_out,
-                             w2,
-                             token_count_slice,
-                             gemm2_out,  // d
-                             std::nullopt,  // expand_idx
-                             std::nullopt,  // c
-                             std::nullopt,  // alpha
-                             std::nullopt,  // beta
-                             std::nullopt,  // a_scale
-                             std::nullopt,  // b_scale
-                             std::nullopt,  // bias
-                             std::nullopt,  // a_calibration
-                             std::nullopt,  // b_calibration
-                             std::nullopt,  // quant_flag
-                             tokens,
-                             false,
-                             true);
+
+  torch::Tensor gemm2_out = create_group_gemm_output(
+      act_out, w2, token_count_slice, dtype.toScalarType());
+
+  // Unified group_gemm call, now only checks the existence of
+  // input_scale/w1_scale for smoothquant
+  tmo::torch_api::group_gemm(
+      act_out,            // a
+      w2,                 // b
+      token_count_slice,  // m_list
+      gemm2_out,          // d
+      std::nullopt,       // gather_idx
+      std::nullopt,       // c
+      std::nullopt,       // alpha
+      std::nullopt,       // beta
+      act_out_scale.defined() ? std::make_optional(act_out_scale)
+                              : std::nullopt,  // a_scale
+      w2_scale.has_value() ? std::make_optional(w2_scale.value())
+                           : std::nullopt,  // b_scale
+      std::nullopt,                         // bias
+      w2_quant_flag.has_value() ? w2_quant_flag : std::nullopt,  // quant_flag
+      std::nullopt,  // b_offset
+      std::nullopt,  // tile_config
+      tokens,        // max_dim
+      false,         // trans_a
+      true           // trans_b
+  );
 
   auto output = torch::empty({reduce_weight.size(0), gemm2_out.size(1)},
                              gemm2_out.options());

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 5 additions & 2 deletions
@@ -143,13 +143,16 @@ torch::Tensor fused_moe(
     bool gated,
     const std::string& act_mode,
     const std::string& scoring_func,
+    int num_expert_group,
+    int topk_group,
+    double route_scale,
     int start_expert_id,
     int block_n,
     bool avg_moe,
     const std::optional<torch::Tensor>& class_reduce_weight,
     const std::optional<torch::Tensor>& class_expert_id,
-    const std::optional<std::vector<bool>>& w1_quant_flag,
-    const std::optional<std::vector<bool>>& w2_quant_flag,
+    const std::optional<torch::List<int64_t>>& w1_quant_flag,
+    const std::optional<torch::List<int64_t>>& w2_quant_flag,
     int world_size,
     int shared_expert_num,
     const std::string& parallel_mode);

xllm/core/kernels/ops_api.cpp

Lines changed: 3 additions & 0 deletions
@@ -169,6 +169,9 @@ torch::Tensor fused_moe(FusedMoEParams& params) {
       params.gated,
       params.act_mode,
       params.scoring_func,
+      params.num_expert_group,
+      params.topk_group,
+      params.route_scale,
       params.start_expert_id,
       params.block_n,
       params.avg_moe,

xllm/core/kernels/param.h

Lines changed: 5 additions & 2 deletions
@@ -159,13 +159,16 @@ struct FusedMoEParams {
   bool gated;
   std::string act_mode;
   std::string scoring_func = "softmax";
+  int num_expert_group = -1;
+  int topk_group = 0;
+  double route_scale = 1.0;
   int start_expert_id = 0;
   int block_n = 0;
   bool avg_moe = false;
   std::optional<torch::Tensor> class_reduce_weight;
   std::optional<torch::Tensor> class_expert_id;
-  std::optional<std::vector<bool>> w1_quant_flag;
-  std::optional<std::vector<bool>> w2_quant_flag;
+  std::optional<torch::List<int64_t>> w1_quant_flag;
+  std::optional<torch::List<int64_t>> w2_quant_flag;
   int world_size = 0;
   int shared_expert_num = 0;
   std::string parallel_mode = "ep";

xllm/core/layers/common/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
@@ -59,3 +59,20 @@ cc_test(
     torch
     GTest::gtest_main
 )
+
+# Add test for FusedMoE
+cc_test(
+  NAME
+    fused_moe_test
+  SRCS
+    tests/fused_moe_tests.cpp
+    tests/tests_utils.cpp
+  DEPS
+    :common_layers
+    :parallel_state
+    :model
+    :state_dict
+    glog::glog
+    torch
+    GTest::gtest_main
+)
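The new cc_test target compiles tests/fused_moe_tests.cpp together with the shared tests/tests_utils.cpp helpers and links the layer, parallel-state, model and state-dict libraries it exercises. Assuming the repository's usual CMake/CTest flow, it can be built on its own with cmake --build <build-dir> --target fused_moe_test and run selectively with ctest -R fused_moe_test from the build directory.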

xllm/core/layers/common/dense_mlp.cpp

Lines changed: 4 additions & 4 deletions
@@ -35,9 +35,9 @@ DenseMLPImpl::DenseMLPImpl(int hidden_size,
       parallel_args_(parallel_args),
       hidden_act_(hidden_act) {
   // Check if using w8a8 smoothquant quantization
-  is_per_token_smoothquant_ = quant_args.quant_method() == "smoothquant";
+  is_smoothquant_ = quant_args.quant_method() == "smoothquant";
 
-  if (is_per_token_smoothquant_) {
+  if (is_smoothquant_) {
     // Safety check: only w8a8 smoothquant is supported
     if (quant_args.bits() != 8 || !quant_args.activation_dynamic()) {
       LOG(FATAL)
@@ -50,7 +50,7 @@ DenseMLPImpl::DenseMLPImpl(int hidden_size,
   // Determine extra args based on quantization mode
   FusedLinearExtraArgs gate_up_proj_extra_args("none", false);
   FusedLinearExtraArgs down_proj_extra_args("none", false);
-  if (is_per_token_smoothquant_) {
+  if (is_smoothquant_) {
     // For per-token smoothquant, use specific args
     down_proj_extra_args = FusedLinearExtraArgs(hidden_act_, is_gated_);
   }
@@ -84,7 +84,7 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) {
   // input shape: [num_tokens, hidden_size]
   auto gate_up = gate_up_proj_->forward(hidden_states);
 
-  if (is_per_token_smoothquant_) {
+  if (is_smoothquant_) {
     // For w8a8 quantization, the active operation is fused with the down_proj
     return down_proj_->forward(gate_up);
   } else {
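The rename from is_per_token_smoothquant_ to is_smoothquant_ reflects that the flag gates the whole w8a8 smoothquant configuration rather than just per-token scaling. Functionally, when the flag is set, down_proj_ is constructed with FusedLinearExtraArgs(hidden_act_, is_gated_), so the activation (and the re-quantization of its output) runs inside the down projection and forward() hands the raw gate_up result straight to down_proj_. In the unquantized branch, which is cut off in this hunk, the activation is presumably still applied in forward() before the down projection, as before. This mirrors the fused_moe change above, where the activation is fused into the second scaled_quantize for the w8a8 path.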

xllm/core/layers/common/dense_mlp.h

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class DenseMLPImpl : public torch::nn::Module {
   ParallelArgs parallel_args_;
   ColumnParallelLinear gate_up_proj_{nullptr};
   RowParallelLinear down_proj_{nullptr};
-  bool is_per_token_smoothquant_;
+  bool is_smoothquant_;
   std::string hidden_act_;
 };
 TORCH_MODULE(DenseMLP);
