
Commit 1df7760

feat: support w8a8 mlp on mlu for deepseek v3.2 prerequisite. (#290)
Co-authored-by: phantomlei <[email protected]>
1 parent 12dde6b commit 1df7760

22 files changed: +1328 −91 lines

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 33 additions & 0 deletions
@@ -154,4 +154,37 @@ torch::Tensor fused_moe(
    int shared_expert_num,
    const std::string& parallel_mode);

std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
    const torch::Tensor& x,
    const torch::Tensor& smooth,
    const std::optional<torch::Tensor>& zero = std::nullopt,
    const std::optional<torch::Tensor>& token_count = std::nullopt,
    const std::optional<torch::Tensor>& gather_index = std::nullopt,
    const std::optional<torch::Tensor>& gather_index_start_position =
        std::nullopt,
    const std::optional<torch::Tensor>& output = std::nullopt,
    const std::optional<torch::Tensor>& output_scale = std::nullopt,
    const std::string& act_mode = "none",
    double active_coef = 1.0,
    bool is_gated = false,
    at::ScalarType quant_type = at::kChar);

torch::Tensor scaled_matmul(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const std::optional<torch::Tensor>& a_scale,
    const torch::Tensor& b_scale,
    c10::ScalarType output_dtype,
    const std::optional<torch::Tensor>& bias = std::nullopt,
    const std::optional<torch::Tensor>& c = std::nullopt,
    const std::string& act_mode = "none",
    int64_t quant_bit_size = 8,
    double alpha = 1.0,
    double beta = 1.0,
    bool use_hp_active = false,
    int64_t a_quant_bit_size = -1,
    const std::optional<torch::Tensor>& a_calib = std::nullopt,
    const std::optional<torch::Tensor>& b_calib = std::nullopt,
    const std::optional<torch::Tensor>& output = std::nullopt);

}  // namespace xllm::kernel::mlu
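Taken together, the two new declarations cover both halves of a w8a8 linear layer on MLU: scaled_quantize produces int8 activations plus dynamic per-token scales, and scaled_matmul consumes them together with pre-quantized weights. A minimal usage sketch follows; the tensor names and shapes (x, smooth, w_int8, w_scale, bfloat16 output) are illustrative assumptions, not code from this commit.

// Hypothetical w8a8 linear: x is [M, K] half/bf16 activations, smooth is the
// per-channel SmoothQuant factor, w_int8 / w_scale are pre-quantized weights.
auto [x_int8, x_scale] = xllm::kernel::mlu::scaled_quantize(x, smooth);

torch::Tensor y = xllm::kernel::mlu::scaled_matmul(
    x_int8,                    // [M, K] int8 activations
    w_int8,                    // [N, K] int8 weights (transposed inside the kernel)
    x_scale,                   // per-token activation scale
    w_scale,                   // per-channel weight scale
    torch::kBFloat16,          // output dtype (half or bfloat16 only)
    /*bias=*/std::nullopt,
    /*c=*/std::nullopt,
    /*act_mode=*/"none",
    /*quant_bit_size=*/8,
    /*alpha=*/1.0,
    /*beta=*/1.0,
    /*use_hp_active=*/false,
    /*a_quant_bit_size=*/8);   // only w8a8 is accepted by the implementation
// y has shape [M, N].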
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlu_ops_api.h"

namespace xllm::kernel::mlu {

torch::Tensor scaled_matmul(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const std::optional<torch::Tensor>& a_scale,
    const torch::Tensor& b_scale,
    c10::ScalarType output_dtype,
    const std::optional<torch::Tensor>& bias /* = c10::nullopt */,
    const std::optional<torch::Tensor>& c /* = c10::nullopt */,
    const std::string& act_mode /* = "none" */,
    int64_t quant_bit_size /* = 8 */,
    double alpha /* = 1.0 */,
    double beta /* = 1.0 */,
    bool use_hp_active /* = false */,
    int64_t a_quant_bit_size /* = -1 */,
    const std::optional<torch::Tensor>& a_calib /* = c10::nullopt */,
    const std::optional<torch::Tensor>& b_calib /* = c10::nullopt */,
    const std::optional<torch::Tensor>& output /* = c10::nullopt */
) {
  // Check: only support w8a8 quantization for now.
  TORCH_CHECK(quant_bit_size == 8 && a_quant_bit_size == 8,
              "scaled_matmul only supports w8a8 quantization (quant_bit_size "
              "== 8, a_quant_bit_size == 8) for now. "
              "Got quant_bit_size = ",
              quant_bit_size,
              ", a_quant_bit_size = ",
              a_quant_bit_size,
              ".");

  // Only support smooth_quant algorithm for now
  std::string quant_algo = "smooth_quant";
  std::string a_quant_layout = (a_scale.value().dim() == 1)
                                   ? "quantize_per_token"
                                   : "quantize_group_wise";
  std::string b_quant_layout = "quantize_per_channel";
  if (b_scale.dim() > 1) {
    if (b_scale.size(0) < b.size(0)) {
      b_quant_layout = "quantize_per_block";
    } else {
      b_quant_layout = "quantize_group_wise";
    }
  }
  std::optional<torch::Tensor> gemm_output_scale = c10::nullopt;

  at::ScalarType torch_half = at::ScalarType::Half;
  at::ScalarType torch_bfloat16 = at::ScalarType::BFloat16;

  TORCH_CHECK(output_dtype == torch_half || output_dtype == torch_bfloat16,
              "output dtype must be half or bfloat16, but got: ",
              output_dtype,
              ".");

  // Select output tensor
  torch::Tensor output_tensor;
  if (output.has_value()) {
    output_tensor = output.value();
  } else {
    output_tensor = at::empty(
        {a.size(0), b.size(0)},
        torch::TensorOptions().dtype(output_dtype).device(a.device()));
  }

  // Call underlying kernel for smooth_quant
  tmo::torch_api::scaled_matmul(output_tensor,
                                a,
                                b,
                                a_scale,
                                c10::nullopt,  // a_zero
                                a_calib,
                                b_scale,
                                c10::nullopt,  // b_zero
                                b_calib,
                                bias,
                                c,
                                c10::nullopt,  // c_scale
                                c10::nullopt,  // c_zero
                                gemm_output_scale,
                                c10::nullopt,  // gemm_output_zero
                                quant_algo,
                                a_quant_layout,
                                b_quant_layout,
                                a_quant_bit_size,
                                quant_bit_size,
                                act_mode,
                                use_hp_active,
                                1.0,  // act_coef
                                alpha,
                                beta,
                                false,  // trans_a
                                true    // trans_b
  );
  return output_tensor;
}

}  // namespace xllm::kernel::mlu
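Reading the layout selection above: the activation scale is treated as per-token when a_scale is 1-D and group-wise otherwise, while the weight scale defaults to per-channel and switches to per-block or group-wise when b_scale is 2-D, depending on whether its leading dimension covers every row of b. Because the underlying kernel is invoked with trans_b = true, the expected shapes on the smooth_quant path are roughly as follows (assumed for illustration, not stated in the commit):

// Assumed shapes for the smooth_quant path:
//   a        : [M, K]  int8 activations
//   a_scale  : [M]     -> "quantize_per_token"
//   b        : [N, K]  int8 weights, consumed with trans_b = true
//   b_scale  : [N]     -> "quantize_per_channel"
//   output   : [M, N]  half or bfloat16, allocated as {a.size(0), b.size(0)}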
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlu_ops_api.h"

namespace xllm::kernel::mlu {
std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
    const torch::Tensor& x,
    const torch::Tensor& smooth,
    const std::optional<torch::Tensor>& zero /* = c10::nullopt */,
    const std::optional<torch::Tensor>& token_count /* = c10::nullopt */,
    const std::optional<torch::Tensor>& gather_index /* = c10::nullopt */,
    const std::optional<torch::Tensor>&
        gather_index_start_position /* = c10::nullopt */,
    const std::optional<torch::Tensor>& output /* = c10::nullopt */,
    const std::optional<torch::Tensor>& output_scale /* = c10::nullopt */,
    const std::string& act_mode /* = "none" */,
    double active_coef /* = 1.0 */,
    bool is_gated /* = false */,
    at::ScalarType quant_type /* = at::kChar */
) {
  // If act_mode is "none", override is_gated to false
  bool gated = is_gated;
  if (act_mode == "none") {
    gated = false;
  }

  // Determine output shape
  auto x_sizes = x.sizes();
  std::vector<int64_t> output_shape(x_sizes.begin(), x_sizes.end());
  std::vector<int64_t> output_scale_shape(x_sizes.begin(), x_sizes.end() - 1);

  // Adjust output shape based on gather_index
  if (gather_index.has_value()) {
    int64_t output_tokens = gather_index.value().size(0);
    output_shape[0] = output_tokens;
    output_scale_shape[0] = output_tokens;
  }

  // Adjust output shape for gated activation
  if (gated) {
    // For gated, output is [..., C//2]
    output_shape.back() = output_shape.back() / 2;
  }

  // Allocate output tensors
  torch::Tensor result_output;
  torch::Tensor result_output_scale;

  if (output.has_value()) {
    result_output = output.value();
  } else {
    result_output = at::empty(output_shape, x.options().dtype(quant_type));
  }

  if (output_scale.has_value()) {
    result_output_scale = output_scale.value();
  } else {
    result_output_scale =
        at::empty(output_scale_shape, x.options().dtype(at::kFloat));
  }

  // Call underlying MLU kernel
  tmo::torch_api::scaled_quantize(x,
                                  result_output,
                                  result_output_scale,
                                  smooth,
                                  zero,
                                  token_count,
                                  gather_index,
                                  gather_index_start_position,
                                  /*scale_upper_bound*/ c10::nullopt,
                                  std::string("dynamic_per_token"),
                                  act_mode,
                                  active_coef,
                                  gated);

  return std::make_tuple(result_output, result_output_scale);
}

}  // namespace xllm::kernel::mlu
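In short, the output shapes follow x except for two adjustments: a gather_index replaces the token dimension with its own length, and a gated activation halves the last dimension because x then carries the gate and up projections side by side. An illustrative walk-through with assumed sizes (not taken from the commit):

// Assumed: x is [T, 2C] half/bf16, act_mode = "silu", is_gated = true.
//   result_output       : [T, C]  int8 (quant_type = at::kChar)
//   result_output_scale : [T]     float, one dynamic scale per token
// With a gather_index of length G, both leading dimensions become G instead of T.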

xllm/core/kernels/ops_api.cpp

Lines changed: 43 additions & 0 deletions
@@ -183,5 +183,48 @@ torch::Tensor fused_moe(FusedMoEParams& params) {
  throw std::runtime_error("fused_moe not implemented");
#endif
}

std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
    ScaledQuantizeParams& params) {
#if defined(USE_MLU)
  return mlu::scaled_quantize(params.x,
                              params.smooth,
                              params.zero,
                              params.token_count,
                              params.gather_index,
                              params.gather_index_start_position,
                              params.output,
                              params.output_scale,
                              params.act_mode,
                              params.active_coef,
                              params.is_gated,
                              params.quant_type);
#else
  throw std::runtime_error("scaled_quantize not implemented");
#endif
}

torch::Tensor scaled_matmul(ScaledMatmulParams& params) {
#if defined(USE_MLU)
  return mlu::scaled_matmul(params.a,
                            params.b,
                            params.a_scale,
                            params.b_scale,
                            params.output_dtype,
                            params.bias,
                            params.c,
                            params.act_mode,
                            params.quant_bit_size,
                            params.alpha,
                            params.beta,
                            params.use_hp_active,
                            params.a_quant_bit_size,
                            params.a_calib,
                            params.b_calib,
                            params.output);
#else
  throw std::runtime_error("scaled_matmul not implemented");
#endif
}
}  // namespace kernel
}  // namespace xllm

xllm/core/kernels/ops_api.h

Lines changed: 5 additions & 0 deletions
@@ -40,5 +40,10 @@ torch::Tensor matmul(MatmulParams& params);

torch::Tensor fused_moe(FusedMoEParams& params);

std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
    ScaledQuantizeParams& params);

torch::Tensor scaled_matmul(ScaledMatmulParams& params);

}  // namespace kernel
}  // namespace xllm

xllm/core/kernels/param.h

Lines changed: 36 additions & 0 deletions
@@ -170,5 +170,41 @@ struct FusedMoEParams {
  int shared_expert_num = 0;
  std::string parallel_mode = "ep";
};

// Per token smooth quantize parameters
struct ScaledQuantizeParams {
  torch::Tensor x;
  torch::Tensor smooth;
  std::optional<torch::Tensor> zero = std::nullopt;
  std::optional<torch::Tensor> token_count = std::nullopt;
  std::optional<torch::Tensor> gather_index = std::nullopt;
  std::optional<torch::Tensor> gather_index_start_position = std::nullopt;
  std::optional<torch::Tensor> output = std::nullopt;
  std::optional<torch::Tensor> output_scale = std::nullopt;
  std::string act_mode = "none";
  double active_coef = 1.0;
  bool is_gated = false;
  torch::ScalarType quant_type = torch::kChar;
};

// Scaled matmul parameters
struct ScaledMatmulParams {
  torch::Tensor a;
  torch::Tensor b;
  std::optional<torch::Tensor> a_scale = std::nullopt;
  torch::Tensor b_scale;
  torch::ScalarType output_dtype;
  std::optional<torch::Tensor> bias = std::nullopt;
  std::optional<torch::Tensor> c = std::nullopt;
  std::string act_mode = "none";
  int64_t quant_bit_size = 8;
  double alpha = 1.0;
  double beta = 1.0;
  bool use_hp_active = false;
  int64_t a_quant_bit_size = -1;
  std::optional<torch::Tensor> a_calib = std::nullopt;
  std::optional<torch::Tensor> b_calib = std::nullopt;
  std::optional<torch::Tensor> output = std::nullopt;
};
}  // namespace kernel
}  // namespace xllm
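A hedged sketch of how a caller such as a w8a8 MLP might drive the new ops through the backend-agnostic dispatchers declared in ops_api.h; the tensors (hidden_states, smooth_factors, weight_int8, weight_scale) are placeholders, not identifiers from this commit.

// Quantize activations, then run the int8 GEMM via the dispatch layer.
xllm::kernel::ScaledQuantizeParams q_params;
q_params.x = hidden_states;        // [T, K] half/bf16 activations (placeholder)
q_params.smooth = smooth_factors;  // per-channel smooth factors (placeholder)
auto [x_int8, x_scale] = xllm::kernel::scaled_quantize(q_params);

xllm::kernel::ScaledMatmulParams mm_params;
mm_params.a = x_int8;
mm_params.b = weight_int8;         // [N, K] int8 weights (placeholder)
mm_params.a_scale = x_scale;
mm_params.b_scale = weight_scale;  // per-channel weight scale (placeholder)
mm_params.output_dtype = torch::kBFloat16;
mm_params.a_quant_bit_size = 8;    // required: only w8a8 is supported
torch::Tensor y = xllm::kernel::scaled_matmul(mm_params);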

xllm/core/layers/common/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
@@ -40,3 +40,20 @@ cc_library(
  gflags::gflags
  torch
)

# Add test for DenseMLP
cc_test(
  NAME
    dense_mlp_test
  SRCS
    tests/dense_mlp_tests.cpp
    tests/tests_utils.cpp
  DEPS
    :common_layers
    :parallel_state
    :model
    :state_dict
    glog::glog
    torch
    GTest::gtest_main
)
