@@ -16,13 +16,17 @@ limitations under the License.
 #include "mlu_ops_api.h"
 
 namespace {
-torch::Tensor create_group_gemm_output(const torch::Tensor& a,
-                                       const torch::Tensor& b,
-                                       const torch::Tensor& group_list) {
+torch::Tensor create_group_gemm_output(
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& group_list,
+    c10::ScalarType dtype = c10::ScalarType::BFloat16) {
+  torch::TensorOptions target_options = a.options().dtype(dtype);
   if (b.dim() != 2) {
-    return torch::empty({a.size(0), b.size(1)}, a.options());
+    return torch::empty({a.size(0), b.size(1)}, target_options);
   }
-  return torch::empty({group_list.size(0), a.size(0), b.size(0)}, a.options());
+  return torch::empty({group_list.size(0), a.size(0), b.size(0)},
+                      target_options);
 }
 }  // namespace
 
@@ -45,13 +49,16 @@ torch::Tensor fused_moe(
     bool gated,
     const std::string& act_mode,
     const std::string& scoring_func,
+    int num_expert_group,
+    int topk_group,
+    double route_scale,
     int start_expert_id,
     int block_n,
     bool avg_moe,
     const std::optional<torch::Tensor>& class_reduce_weight,
     const std::optional<torch::Tensor>& class_expert_id,
-    const std::optional<std::vector<bool>>& w1_quant_flag,
-    const std::optional<std::vector<bool>>& w2_quant_flag,
+    const std::optional<torch::List<int64_t>>& w1_quant_flag,
+    const std::optional<torch::List<int64_t>>& w2_quant_flag,
     int world_size,
     int shared_expert_num,
     const std::string& parallel_mode) {
@@ -67,25 +74,33 @@ torch::Tensor fused_moe(
     residual_2d = residual.value().reshape({-1, residual.value().size(-1)});
   }
 
+  // check smooth quant variables
+  bool all_present = input_smooth && act_smooth && w1_scale && w2_scale;
+  bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale;
+  TORCH_CHECK(all_none || all_present,
+              "input_smooth, act_smooth, w1_scale and w2_scale must be present "
+              "or absent at the same time.");
+  bool is_smoothquant = all_present;
   int64_t expert_num = gating_output_2d.size(-1);
   int64_t expert_size = w1.size(0);
 
-  // softmax_topk
+  // apply softmax_topk or sigmoid_topk
   auto reduce_weight = torch::empty(
       {gating_output_2d.size(0), topk},
       torch::dtype(torch::kFloat).device(gating_output_2d.device()));
   auto expert_id = torch::empty(
       {gating_output_2d.size(0), topk},
       torch::dtype(torch::kInt32).device(gating_output_2d.device()));
+
   tmo::torch_api::moe_active_topk(gating_output_2d,
                                   topk,
-                                  -1,
-                                  0,
+                                  num_expert_group,
+                                  topk_group,
                                   renormalize,
-                                  std::nullopt,
-                                  "topk_logit",
+                                  std::nullopt,  // mask
+                                  "topk_logit",  // normed_by
                                   scoring_func,
-                                  1.0,
+                                  route_scale,
                                   e_score_correction_bias,
                                   reduce_weight,
                                   expert_id);
@@ -95,69 +110,137 @@ torch::Tensor fused_moe(
   auto combine_idx = output_vec[1];
   auto token_count = output_vec[2];
   auto cusum_token_count = output_vec[3];
-  torch::Tensor expand_hidden_states =
-      tmo::torch_api::moe_expand_input(hidden_states_2d,
-                                       expand_idx,
-                                       cusum_token_count,
-                                       start_expert_id,
-                                       expert_size);
 
+  // prepare the parameters for the first group gemm
   auto token_count_slice =
       token_count.slice(0, start_expert_id, start_expert_id + expert_size);
-  torch::Tensor gemm1_out =
-      create_group_gemm_output(expand_hidden_states, w1, token_count_slice);
-  tmo::torch_api::group_gemm(expand_hidden_states,
-                             w1,
-                             token_count_slice,
-                             gemm1_out,
-                             std::nullopt,  // expand_idx
-                             std::nullopt,  // c
-                             std::nullopt,  // alpha
-                             std::nullopt,  // beta
-                             std::nullopt,  // a_scale
-                             std::nullopt,  // b_scale
-                             std::nullopt,  // bias
-                             std::nullopt,  // a_calibration
-                             std::nullopt,  // b_calibration
-                             std::nullopt,  // quant_flag
-                             tokens,
-                             false,
-                             true);
+  auto gather_index_start_position =
+      cusum_token_count.index({start_expert_id}).unsqueeze(0);
+  torch::Tensor expand_hidden_states;
+  torch::Tensor input_scale;
+
+  if (is_smoothquant) {
+    // w8a8 path: quantize input hidden states directly (fused with
+    // moe_expand_input)
+    std::tie(expand_hidden_states, input_scale) =
+        xllm::kernel::mlu::scaled_quantize(
+            hidden_states_2d,  // Use original hidden_states_2d instead of
+                               // expand_hidden_states
+            input_smooth.value(),
+            std::nullopt,  // zero
+            token_count_slice,
+            expand_idx,
+            gather_index_start_position,
+            std::nullopt,  // output
+            std::nullopt,  // output_scale
+            "none",        // act_mode
+            1.0,           // active_coef
+            false,         // is_gated
+            torch::kChar   // quant_type
+        );
+  } else {
+    // bf16/fp32 path: expand input hidden states
+    expand_hidden_states = tmo::torch_api::moe_expand_input(hidden_states_2d,
+                                                            expand_idx,
+                                                            cusum_token_count,
+                                                            start_expert_id,
+                                                            expert_size);
+  }
+
+  torch::Tensor gemm1_out = create_group_gemm_output(
+      expand_hidden_states, w1, token_count_slice, dtype.toScalarType());
 
+  // Unified group_gemm call using input_scale/w1_scale/quant_flag only if
+  // present
+  tmo::torch_api::group_gemm(
+      expand_hidden_states,  // a
+      w1,                    // b
+      token_count_slice,     // m_list
+      gemm1_out,             // d
+      std::nullopt,          // gather_idx
+      std::nullopt,          // c
+      std::nullopt,          // alpha
+      std::nullopt,          // beta
+      input_scale.defined() ? std::make_optional(input_scale)
+                            : std::nullopt,  // a_scale
+      w1_scale.has_value() ? std::make_optional(w1_scale.value())
+                           : std::nullopt,  // b_scale
+      std::nullopt,  // bias
+      w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt,  // quant_flag
+      std::nullopt,  // b_offset
+      std::nullopt,  // tile_config
+      tokens,        // max_dim
+      false,         // trans_a
+      true           // trans_b
+  );
+
+  // prepare the parameters for the second group gemm
   torch::Tensor act_out;
-  if (gated) {
-    act_out = gemm1_out.slice(1, 0, gemm1_out.size(1) / 2);
+  torch::Tensor act_out_scale;
+  if (is_smoothquant) {
+    // w8a8 path: reuse quantized_input and input_scale from first group_gemm
+    act_out = gated ? expand_hidden_states.slice(1, 0, gemm1_out.size(1) / 2)
+                    : expand_hidden_states.slice(1, 0, gemm1_out.size(1));
+    act_out_scale = input_scale.slice(0, 0, gemm1_out.size(0));
+
+    // Quantize gemm1_out directly (fused with active operation) using reused
+    // tensors
+    auto [quantized_activation, activation_scale] =
+        xllm::kernel::mlu::scaled_quantize(
+            gemm1_out,
+            act_smooth.value(),
+            std::nullopt,   // zero
+            token_count_slice,
+            std::nullopt,   // gather_index
+            std::nullopt,   // gather_index_start_position
+            act_out,        // output - reuse from quantized_input
+            act_out_scale,  // output_scale - reuse from input_scale
+            act_mode,       // act_mode
+            1.0,            // active_coef
+            gated,          // is_gated
+            torch::kChar    // quant_type
+        );
+    act_out = quantized_activation;
+    act_out_scale = activation_scale;
   } else {
-    act_out = gemm1_out;
+    // bf16/fp32 path: apply activation function first
+    act_out = gated ? gemm1_out.slice(1, 0, gemm1_out.size(1) / 2) : gemm1_out;
+    tmo::torch_api::active(gemm1_out,
+                           act_out,
+                           bias1,
+                           cusum_token_count,
+                           act_mode,
+                           gated,
+                           start_expert_id,
+                           expert_size);
   }
-  tmo::torch_api::active(gemm1_out,
-                         act_out,
-                         bias1,
-                         cusum_token_count,
-                         act_mode,
-                         gated,
-                         start_expert_id,
-                         expert_size);
-
-  torch::Tensor gemm2_out =
-      create_group_gemm_output(act_out, w2, token_count_slice);
-  tmo::torch_api::group_gemm(act_out,
-                             w2,
-                             token_count_slice,
-                             gemm2_out,  // d
-                             std::nullopt,  // expand_idx
-                             std::nullopt,  // c
-                             std::nullopt,  // alpha
-                             std::nullopt,  // beta
-                             std::nullopt,  // a_scale
-                             std::nullopt,  // b_scale
-                             std::nullopt,  // bias
-                             std::nullopt,  // a_calibration
-                             std::nullopt,  // b_calibration
-                             std::nullopt,  // quant_flag
-                             tokens,
-                             false,
-                             true);
+
+  torch::Tensor gemm2_out = create_group_gemm_output(
+      act_out, w2, token_count_slice, dtype.toScalarType());
+
+  // Unified group_gemm call, now only checks the existence of
+  // input_scale/w1_scale for smoothquant
+  tmo::torch_api::group_gemm(
+      act_out,            // a
+      w2,                 // b
+      token_count_slice,  // m_list
+      gemm2_out,          // d
+      std::nullopt,       // gather_idx
+      std::nullopt,       // c
+      std::nullopt,       // alpha
+      std::nullopt,       // beta
+      act_out_scale.defined() ? std::make_optional(act_out_scale)
+                              : std::nullopt,  // a_scale
+      w2_scale.has_value() ? std::make_optional(w2_scale.value())
+                           : std::nullopt,  // b_scale
+      std::nullopt,  // bias
+      w2_quant_flag.has_value() ? w2_quant_flag : std::nullopt,  // quant_flag
+      std::nullopt,  // b_offset
+      std::nullopt,  // tile_config
+      tokens,        // max_dim
+      false,         // trans_a
+      true           // trans_b
+  );
 
   auto output = torch::empty({reduce_weight.size(0), gemm2_out.size(1)},
                              gemm2_out.options());
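
For context, a minimal self-contained sketch of the allocation contract changed in the first hunk: create_group_gemm_output now takes an explicit output dtype (defaulting to bfloat16) instead of inheriting a.options(). The helper below is copied from the hunk; the shapes in main() are illustrative assumptions, not values taken from this PR.

#include <torch/torch.h>

// Mirror of the helper as modified in this diff.
torch::Tensor create_group_gemm_output(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& group_list,
    c10::ScalarType dtype = c10::ScalarType::BFloat16) {
  torch::TensorOptions target_options = a.options().dtype(dtype);
  if (b.dim() != 2) {
    return torch::empty({a.size(0), b.size(1)}, target_options);
  }
  return torch::empty({group_list.size(0), a.size(0), b.size(0)},
                      target_options);
}

int main() {
  // Illustrative shapes only: 128 expanded tokens, hidden size 256,
  // 8 experts with an intermediate size of 512.
  auto a = torch::empty({128, 256}, torch::kFloat);
  auto w = torch::empty({8, 512, 256}, torch::kFloat);  // 3-D grouped weight
  auto counts = torch::zeros({8}, torch::kInt32);       // tokens per expert

  // Even though `a` is fp32, the output is allocated as bfloat16 by default.
  auto out = create_group_gemm_output(a, w, counts);
  // out.sizes() == {128, 512}, out.dtype() == torch::kBFloat16
  return 0;
}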