@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16+ #include < glog/logging.h>
17+
1618#include " mlu_ops_api.h"
1719
1820namespace {
1921torch::Tensor create_group_gemm_output (
2022 const torch::Tensor& a,
2123 const torch::Tensor& b,
2224 const torch::Tensor& group_list,
23- c10 ::ScalarType dtype = c10 ::ScalarType::BFloat16) {
25+ torch ::ScalarType dtype = torch ::ScalarType::BFloat16) {
2426 torch::TensorOptions target_options = a.options ().dtype (dtype);
2527 if (b.dim () != 2 ) {
2628 return torch::empty ({a.size (0 ), b.size (1 )}, target_options);
@@ -77,9 +79,9 @@ torch::Tensor fused_moe(
7779 // check smooth quant variables
7880 bool all_present = input_smooth && act_smooth && w1_scale && w2_scale;
7981 bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale;
80- TORCH_CHECK (all_none || all_present,
81- " input_smooth, act_smooth, w1_scale and w2_scale must be present "
82- " or absent at the same time." ) ;
82+ CHECK (all_none || all_present)
83+ << " input_smooth, act_smooth, w1_scale and w2_scale must be present or "
84+ " absent at the same time." ;
8385 bool is_smoothquant = all_present;
8486 int64_t expert_num = gating_output_2d.size (-1 );
8587 int64_t expert_size = w1.size (0 );
@@ -97,8 +99,8 @@ torch::Tensor fused_moe(
9799 num_expert_group,
98100 topk_group,
99101 renormalize,
100- std::nullopt , // mask
101- " topk_logit" , // normed_by
102+ /* mask= */ std::nullopt ,
103+ /* normed_by= */ " topk_logit" ,
102104 scoring_func,
103105 route_scale,
104106 e_score_correction_bias,
@@ -123,21 +125,19 @@ torch::Tensor fused_moe(
123125 // w8a8 path: quantize input hidden states directly (fused with
124126 // moe_expand_input)
125127 std::tie (expand_hidden_states, input_scale) =
126- xllm::kernel::mlu::scaled_quantize (
127- hidden_states_2d, // Use original hidden_states_2d instead of
128- // expand_hidden_states
129- input_smooth.value (),
130- std::nullopt , // zero
131- token_count_slice,
132- expand_idx,
133- gather_index_start_position,
134- std::nullopt , // output
135- std::nullopt , // output_scale
136- " none" , // act_mode
137- 1.0 , // active_coef
138- false , // is_gated
139- torch::kChar // quant_type
140- );
128+ scaled_quantize (hidden_states_2d, // Use original hidden_states_2d
129+ // instead of expand_hidden_states
130+ input_smooth.value (),
131+ /* zero=*/ std::nullopt ,
132+ token_count_slice,
133+ expand_idx,
134+ gather_index_start_position,
135+ /* output=*/ std::nullopt ,
136+ /* output_scale=*/ std::nullopt ,
137+ /* act_mode=*/ " none" ,
138+ /* active_coef=*/ 1.0 ,
139+ /* is_gated=*/ false ,
140+ /* quant_type=*/ torch::kChar );
141141 } else {
142142 // bf16/fp32 path: expand input hidden states
143143 expand_hidden_states = tmo::torch_api::moe_expand_input (hidden_states_2d,
@@ -153,26 +153,25 @@ torch::Tensor fused_moe(
153153 // Unified group_gemm call using input_scale/w1_scale/quant_flag only if
154154 // present
155155 tmo::torch_api::group_gemm (
156- expand_hidden_states, // a
157- w1, // b
158- token_count_slice, // m_list
159- gemm1_out, // d
160- std::nullopt , // gather_idx
161- std::nullopt , // c
162- std::nullopt , // alpha
163- std::nullopt , // beta
164- input_scale.defined () ? std::make_optional (input_scale)
165- : std::nullopt , // a_scale
166- w1_scale.has_value () ? std::make_optional (w1_scale.value ())
167- : std::nullopt , // b_scale
168- std::nullopt , // bias
169- w1_quant_flag.has_value () ? w1_quant_flag : std::nullopt , // quant_flag
170- std::nullopt , // b_offset
171- std::nullopt , // tile_config
172- tokens, // max_dim
173- false , // trans_a
174- true // trans_b
175- );
156+ expand_hidden_states,
157+ w1,
158+ token_count_slice,
159+ gemm1_out,
160+ /* gather_idx=*/ std::nullopt ,
161+ /* c=*/ std::nullopt ,
162+ /* alpha=*/ std::nullopt ,
163+ /* beta=*/ std::nullopt ,
164+ /* a_scale=*/ input_scale.defined () ? std::make_optional (input_scale)
165+ : std::nullopt ,
166+ /* b_scale=*/ w1_scale.has_value () ? std::make_optional (w1_scale.value ())
167+ : std::nullopt ,
168+ /* bias=*/ std::nullopt ,
169+ /* quant_flag=*/ w1_quant_flag.has_value () ? w1_quant_flag : std::nullopt ,
170+ /* b_offset=*/ std::nullopt ,
171+ /* tile_config=*/ std::nullopt ,
172+ /* max_dim=*/ tokens,
173+ /* trans_a=*/ false ,
174+ /* trans_b=*/ true );
176175
177176 // prepare the parameters for the second group gemm
178177 torch::Tensor act_out;
@@ -186,20 +185,18 @@ torch::Tensor fused_moe(
186185 // Quantize gemm1_out directly (fused with active operation) using reused
187186 // tensors
188187 auto [quantized_activation, activation_scale] =
189- xllm::kernel::mlu::scaled_quantize (
190- gemm1_out,
191- act_smooth.value (),
192- std::nullopt , // zero
193- token_count_slice,
194- std::nullopt , // gather_index
195- std::nullopt , // gather_index_start_position
196- act_out, // output - reuse from quantized_input
197- act_out_scale, // output_scale - reuse from input_scale
198- act_mode, // act_mode
199- 1.0 , // active_coef
200- gated, // is_gated
201- torch::kChar // quant_type
202- );
188+ scaled_quantize (gemm1_out,
189+ act_smooth.value (),
190+ /* zero=*/ std::nullopt ,
191+ /* token_count=*/ token_count_slice,
192+ /* gather_index=*/ std::nullopt ,
193+ /* gather_index_start_position=*/ std::nullopt ,
194+ act_out, // output - reuse from quantized_input
195+ act_out_scale, // output_scale - reuse from input_scale
196+ /* act_mode=*/ act_mode,
197+ /* active_coef=*/ 1.0 ,
198+ /* is_gated=*/ gated,
199+ /* quant_type=*/ torch::kChar );
203200 act_out = quantized_activation;
204201 act_out_scale = activation_scale;
205202 } else {
@@ -221,26 +218,25 @@ torch::Tensor fused_moe(
221218 // Unified group_gemm call, now only checks the existence of
222219 // input_scale/w1_scale for smoothquant
223220 tmo::torch_api::group_gemm (
224- act_out, // a
225- w2, // b
226- token_count_slice, // m_list
227- gemm2_out, // d
228- std::nullopt , // gather_idx
229- std::nullopt , // c
230- std::nullopt , // alpha
231- std::nullopt , // beta
221+ act_out,
222+ w2,
223+ token_count_slice,
224+ gemm2_out,
225+ /* gather_idx= */ std::nullopt ,
226+ /* c= */ std::nullopt ,
227+ /* alpha= */ std::nullopt ,
228+ /* beta= */ std::nullopt ,
232229 act_out_scale.defined () ? std::make_optional (act_out_scale)
233230 : std::nullopt , // a_scale
234231 w2_scale.has_value () ? std::make_optional (w2_scale.value ())
235- : std::nullopt , // b_scale
236- std::nullopt , // bias
232+ : std::nullopt , // b_scale
233+ /* bias= */ std::nullopt ,
237234 w2_quant_flag.has_value () ? w2_quant_flag : std::nullopt , // quant_flag
238- std::nullopt , // b_offset
239- std::nullopt , // tile_config
240- tokens, // max_dim
241- false , // trans_a
242- true // trans_b
243- );
235+ /* b_offset=*/ std::nullopt ,
236+ /* tile_config=*/ std::nullopt ,
237+ tokens, // max_dim
238+ /* trans_a=*/ false ,
239+ /* trans_b=*/ true );
244240
245241 auto output = torch::empty ({reduce_weight.size (0 ), gemm2_out.size (1 )},
246242 gemm2_out.options ());
0 commit comments