
Commit 850ced1

phantomlei3 authored and yq33victor committed
fix: optimize the style of codes.
1 parent 5d86da2 commit 850ced1

File tree: 5 files changed (+126 -130 lines)

xllm/core/kernels/mlu/fused_moe.cpp

Lines changed: 67 additions & 71 deletions
@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <glog/logging.h>
+
 #include "mlu_ops_api.h"
 
 namespace {
 torch::Tensor create_group_gemm_output(
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& group_list,
-    c10::ScalarType dtype = c10::ScalarType::BFloat16) {
+    torch::ScalarType dtype = torch::ScalarType::BFloat16) {
   torch::TensorOptions target_options = a.options().dtype(dtype);
   if (b.dim() != 2) {
     return torch::empty({a.size(0), b.size(1)}, target_options);
@@ -77,9 +79,9 @@ torch::Tensor fused_moe(
   // check smooth quant variables
   bool all_present = input_smooth && act_smooth && w1_scale && w2_scale;
   bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale;
-  TORCH_CHECK(all_none || all_present,
-              "input_smooth, act_smooth, w1_scale and w2_scale must be present "
-              "or absent at the same time.");
+  CHECK(all_none || all_present)
+      << "input_smooth, act_smooth, w1_scale and w2_scale must be present or "
+         "absent at the same time.";
   bool is_smoothquant = all_present;
   int64_t expert_num = gating_output_2d.size(-1);
   int64_t expert_size = w1.size(0);
@@ -97,8 +99,8 @@ torch::Tensor fused_moe(
       num_expert_group,
       topk_group,
       renormalize,
-      std::nullopt, // mask
-      "topk_logit", // normed_by
+      /*mask=*/std::nullopt,
+      /*normed_by=*/"topk_logit",
       scoring_func,
       route_scale,
       e_score_correction_bias,
@@ -123,21 +125,19 @@ torch::Tensor fused_moe(
     // w8a8 path: quantize input hidden states directly (fused with
     // moe_expand_input)
     std::tie(expand_hidden_states, input_scale) =
-        xllm::kernel::mlu::scaled_quantize(
-            hidden_states_2d, // Use original hidden_states_2d instead of
-                              // expand_hidden_states
-            input_smooth.value(),
-            std::nullopt, // zero
-            token_count_slice,
-            expand_idx,
-            gather_index_start_position,
-            std::nullopt, // output
-            std::nullopt, // output_scale
-            "none", // act_mode
-            1.0, // active_coef
-            false, // is_gated
-            torch::kChar // quant_type
-        );
+        scaled_quantize(hidden_states_2d, // Use original hidden_states_2d
+                                          // instead of expand_hidden_states
+                        input_smooth.value(),
+                        /*zero=*/std::nullopt,
+                        token_count_slice,
+                        expand_idx,
+                        gather_index_start_position,
+                        /*output=*/std::nullopt,
+                        /*output_scale=*/std::nullopt,
+                        /*act_mode=*/"none",
+                        /*active_coef=*/1.0,
+                        /*is_gated=*/false,
+                        /*quant_type=*/torch::kChar);
   } else {
     // bf16/fp32 path: expand input hidden states
     expand_hidden_states = tmo::torch_api::moe_expand_input(hidden_states_2d,
@@ -153,26 +153,25 @@ torch::Tensor fused_moe(
   // Unified group_gemm call using input_scale/w1_scale/quant_flag only if
   // present
   tmo::torch_api::group_gemm(
-      expand_hidden_states, // a
-      w1, // b
-      token_count_slice, // m_list
-      gemm1_out, // d
-      std::nullopt, // gather_idx
-      std::nullopt, // c
-      std::nullopt, // alpha
-      std::nullopt, // beta
-      input_scale.defined() ? std::make_optional(input_scale)
-                            : std::nullopt, // a_scale
-      w1_scale.has_value() ? std::make_optional(w1_scale.value())
-                           : std::nullopt, // b_scale
-      std::nullopt, // bias
-      w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt, // quant_flag
-      std::nullopt, // b_offset
-      std::nullopt, // tile_config
-      tokens, // max_dim
-      false, // trans_a
-      true // trans_b
-  );
+      expand_hidden_states,
+      w1,
+      token_count_slice,
+      gemm1_out,
+      /*gather_idx=*/std::nullopt,
+      /*c=*/std::nullopt,
+      /*alpha=*/std::nullopt,
+      /*beta=*/std::nullopt,
+      /*a_scale=*/input_scale.defined() ? std::make_optional(input_scale)
+                                        : std::nullopt,
+      /*b_scale=*/w1_scale.has_value() ? std::make_optional(w1_scale.value())
+                                       : std::nullopt,
+      /*bias=*/std::nullopt,
+      /*quant_flag=*/w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt,
+      /*b_offset=*/std::nullopt,
+      /*tile_config=*/std::nullopt,
+      /*max_dim=*/tokens,
+      /*trans_a=*/false,
+      /*trans_b=*/true);
 
   // prepare the parameters for the second group gemm
   torch::Tensor act_out;
@@ -186,20 +185,18 @@ torch::Tensor fused_moe(
     // Quantize gemm1_out directly (fused with active operation) using reused
     // tensors
     auto [quantized_activation, activation_scale] =
-        xllm::kernel::mlu::scaled_quantize(
-            gemm1_out,
-            act_smooth.value(),
-            std::nullopt, // zero
-            token_count_slice,
-            std::nullopt, // gather_index
-            std::nullopt, // gather_index_start_position
-            act_out, // output - reuse from quantized_input
-            act_out_scale, // output_scale - reuse from input_scale
-            act_mode, // act_mode
-            1.0, // active_coef
-            gated, // is_gated
-            torch::kChar // quant_type
-        );
+        scaled_quantize(gemm1_out,
+                        act_smooth.value(),
+                        /*zero=*/std::nullopt,
+                        /*token_count=*/token_count_slice,
+                        /*gather_index=*/std::nullopt,
+                        /*gather_index_start_position=*/std::nullopt,
+                        act_out, // output - reuse from quantized_input
+                        act_out_scale, // output_scale - reuse from input_scale
+                        /*act_mode=*/act_mode,
+                        /*active_coef=*/1.0,
+                        /*is_gated=*/gated,
+                        /*quant_type=*/torch::kChar);
     act_out = quantized_activation;
     act_out_scale = activation_scale;
   } else {
@@ -221,26 +218,25 @@ torch::Tensor fused_moe(
   // Unified group_gemm call, now only checks the existance of
   // input_scale/w1_scale for smoothquant
   tmo::torch_api::group_gemm(
-      act_out, // a
-      w2, // b
-      token_count_slice, // m_list
-      gemm2_out, // d
-      std::nullopt, // gather_idx
-      std::nullopt, // c
-      std::nullopt, // alpha
-      std::nullopt, // beta
+      act_out,
+      w2,
+      token_count_slice,
+      gemm2_out,
+      /*gather_idx=*/std::nullopt,
+      /*c=*/std::nullopt,
+      /*alpha=*/std::nullopt,
+      /*beta=*/std::nullopt,
       act_out_scale.defined() ? std::make_optional(act_out_scale)
                               : std::nullopt, // a_scale
       w2_scale.has_value() ? std::make_optional(w2_scale.value())
-                           : std::nullopt, // b_scale
-      std::nullopt, // bias
+                           : std::nullopt,    // b_scale
+      /*bias=*/std::nullopt,
       w2_quant_flag.has_value() ? w2_quant_flag : std::nullopt, // quant_flag
-      std::nullopt, // b_offset
-      std::nullopt, // tile_config
-      tokens, // max_dim
-      false, // trans_a
-      true // trans_b
-  );
+      /*b_offset=*/std::nullopt,
+      /*tile_config=*/std::nullopt,
+      tokens, // max_dim
+      /*trans_a=*/false,
+      /*trans_b=*/true);
 
   auto output = torch::empty({reduce_weight.size(0), gemm2_out.size(1)},
                              gemm2_out.options());
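
Note on the check-macro change above: the hunk swaps TORCH_CHECK for glog's CHECK (hence the new <glog/logging.h> include). As a minimal illustration of the two styles, not taken from the commit itself:

// Sketch only: TORCH_CHECK throws a c10::Error carrying the message,
// while glog's CHECK logs the message and aborts the process.
#include <glog/logging.h>
#include <torch/torch.h>

void require(bool ok) {
  TORCH_CHECK(ok, "condition must hold");  // style used before this commit
  CHECK(ok) << "condition must hold";      // style adopted by this commit
}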

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 2 additions & 2 deletions
@@ -170,14 +170,14 @@ std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
     const std::string& act_mode = "none",
     double active_coef = 1.0,
     bool is_gated = false,
-    at::ScalarType quant_type = at::kChar);
+    torch::ScalarType quant_type = torch::kChar);
 
 torch::Tensor scaled_matmul(
     const torch::Tensor& a,
     const torch::Tensor& b,
     const std::optional<torch::Tensor>& a_scale,
     const torch::Tensor& b_scale,
-    c10::ScalarType output_dtype,
+    torch::ScalarType output_dtype,
     const std::optional<torch::Tensor>& bias = std::nullopt,
     const std::optional<torch::Tensor>& c = std::nullopt,
     const std::string& act_mode = "none",
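
For context on the type renames in this header (a property of the PyTorch C++ API, not something stated in the diff): torch::ScalarType, at::ScalarType, and c10::ScalarType name the same enum, and torch::kChar is the same int8 constant as at::kChar, so these hunks change spelling only, not behavior. A sketch that would verify that assumption at compile time:

#include <torch/torch.h>
#include <type_traits>

// Sketch only: confirms that the spellings used before and after this
// commit refer to the same underlying type and constant.
static_assert(std::is_same_v<torch::ScalarType, c10::ScalarType>,
              "torch:: and c10:: spellings are aliases");
static_assert(std::is_same_v<torch::ScalarType, at::ScalarType>,
              "torch:: and at:: spellings are aliases");
static_assert(torch::kChar == at::kChar, "kChar is the same constant");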

xllm/core/kernels/mlu/scaled_matmul.cpp

Lines changed: 15 additions & 16 deletions
@@ -22,18 +22,18 @@ torch::Tensor scaled_matmul(
     const torch::Tensor& b,
     const std::optional<torch::Tensor>& a_scale,
     const torch::Tensor& b_scale,
-    c10::ScalarType output_dtype,
-    const std::optional<torch::Tensor>& bias /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& c /* = c10::nullopt */,
+    torch::ScalarType output_dtype,
+    const std::optional<torch::Tensor>& bias /* = std::nullopt */,
+    const std::optional<torch::Tensor>& c /* = std::nullopt */,
     const std::string& act_mode /* = "none" */,
     int64_t quant_bit_size /* = 8 */,
     double alpha /* = 1.0 */,
     double beta /* = 1.0 */,
     bool use_hp_active /* = false */,
     int64_t a_quant_bit_size /* = -1 */,
-    const std::optional<torch::Tensor>& a_calib /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& b_calib /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& output /* = c10::nullopt */
+    const std::optional<torch::Tensor>& a_calib /* = std::nullopt */,
+    const std::optional<torch::Tensor>& b_calib /* = std::nullopt */,
+    const std::optional<torch::Tensor>& output /* = std::nullopt */
 ) {
   // Check: only support w8a8 quantization for now.
   TORCH_CHECK(quant_bit_size == 8 && a_quant_bit_size == 8,
@@ -58,7 +58,7 @@ torch::Tensor scaled_matmul(
       b_quant_layout = "quantize_group_wise";
     }
   }
-  std::optional<torch::Tensor> gemm_output_scale = c10::nullopt;
+  std::optional<torch::Tensor> gemm_output_scale = std::nullopt;
 
   at::ScalarType torch_half = at::ScalarType::Half;
   at::ScalarType torch_bfloat16 = at::ScalarType::BFloat16;
@@ -83,30 +83,29 @@ torch::Tensor scaled_matmul(
       a,
       b,
       a_scale,
-      c10::nullopt, // a_zero
+      /*a_zero=*/std::nullopt,
       a_calib,
       b_scale,
-      c10::nullopt, // b_zero
+      /*b_zero=*/std::nullopt,
       b_calib,
       bias,
       c,
-      c10::nullopt, // c_scale
-      c10::nullopt, // c_zero
+      /*c_scale=*/std::nullopt,
+      /*c_zero=*/std::nullopt,
       gemm_output_scale,
-      c10::nullopt, // gemm_output_zero
+      /*gemm_output_zero=*/std::nullopt,
       quant_algo,
       a_quant_layout,
       b_quant_layout,
       a_quant_bit_size,
       quant_bit_size,
       act_mode,
      use_hp_active,
-      1.0, // act_coef
+      /*act_coef=*/1.0,
      alpha,
      beta,
-      false, // trans_a
-      true // trans_b
-  );
+      /*trans_a=*/false,
+      /*trans_b=*/true);
   return output_tensor;
 }
 
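
A brief note on the c10::nullopt → std::nullopt replacement running through this file (an assumption about the underlying library, not stated in the diff): in recent PyTorch releases c10::optional is an alias of std::optional and c10::nullopt of std::nullopt, so the substitution is a pure spelling cleanup. A minimal sketch:

#include <optional>
#include <torch/torch.h>

// Sketch only: both spellings produce an empty optional tensor.
std::optional<torch::Tensor> old_style = c10::nullopt;  // spelling replaced by this commit
std::optional<torch::Tensor> new_style = std::nullopt;  // spelling adopted by this commit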

xllm/core/kernels/mlu/scaled_quantize.cpp

Lines changed: 20 additions & 19 deletions
@@ -19,13 +19,13 @@ namespace xllm::kernel::mlu {
 std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
     const torch::Tensor& x,
     const torch::Tensor& smooth,
-    const std::optional<torch::Tensor>& zero /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& token_count /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& gather_index /* = c10::nullopt */,
+    const std::optional<torch::Tensor>& zero /* = std::nullopt */,
+    const std::optional<torch::Tensor>& token_count /* = std::nullopt */,
+    const std::optional<torch::Tensor>& gather_index /* = std::nullopt */,
     const std::optional<torch::Tensor>&
-        gather_index_start_position /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& output /* = c10::nullopt */,
-    const std::optional<torch::Tensor>& output_scale /* = c10::nullopt */,
+        gather_index_start_position /* = std::nullopt */,
+    const std::optional<torch::Tensor>& output /* = std::nullopt */,
+    const std::optional<torch::Tensor>& output_scale /* = std::nullopt */,
     const std::string& act_mode /* = "none" */,
     double active_coef /* = 1.0 */,
     bool is_gated /* = false */,
@@ -73,19 +73,20 @@ std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
   }
 
   // Call underlying MLU kernel
-  tmo::torch_api::scaled_quantize(x,
-                                  result_output,
-                                  result_output_scale,
-                                  smooth,
-                                  zero,
-                                  token_count,
-                                  gather_index,
-                                  gather_index_start_position,
-                                  /*scale_upper_bound*/ c10::nullopt,
-                                  std::string("dynamic_per_token"),
-                                  act_mode,
-                                  active_coef,
-                                  gated);
+  tmo::torch_api::scaled_quantize(
+      x,
+      result_output,
+      result_output_scale,
+      smooth,
+      zero,
+      token_count,
+      gather_index,
+      gather_index_start_position,
+      /*scale_upper_bound*/ std::nullopt,
+      /*quant_algo=*/std::string("dynamic_per_token"),
+      act_mode,
+      active_coef,
+      gated);
 
   return std::make_tuple(result_output, result_output_scale);
 }
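
The /*param=*/ inline comments adopted throughout this commit (replacing trailing // comments) keep the parameter name attached to the argument it annotates. A small illustration with a hypothetical function, not taken from the repository:

// Sketch only: resize() and demo() are hypothetical; they show how
// /*name=*/ comments document which parameter each literal binds to.
void resize(int width, int height, bool keep_aspect);

void demo() {
  resize(/*width=*/640, /*height=*/480, /*keep_aspect=*/true);
}

Tools such as clang-tidy's bugprone-argument-comment check can verify that these names match the actual parameter names, which trailing // comments cannot offer.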
