
Commit 5d17734

[CPU] Block-wise QMoE kernel for CPU (#26009)
This PR adds a block-wise quantization kernel for QMoE on CPU.
1 parent 203e2b2 commit 5d17734
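
As background for the diff below: block-wise quantization stores one scale per block of block_size weights along each row, rather than a single scale per row. The following is a minimal sketch, not the kernel added by this commit, of dequantizing one 4-bit packed expert weight row under assumed conventions (symmetric quantization with no zero point, two uint4 values packed per uint8 as with pack_size = 2, low nibble first); the function name and layout are illustrative only.

#include <cstdint>

// Hypothetical sketch (not ONNX Runtime code): dequantize one row of an expert
// weight matrix quantized to 4 bits, two values packed per uint8, with one float
// scale per block of `block_size` columns. Symmetric quantization is assumed,
// so the quantized value 8 maps back to 0.0f.
void DequantizeRowBlockwise(const uint8_t* packed_row,  // hidden_size / 2 bytes
                            const float* row_scales,    // ceil(hidden_size / block_size) scales
                            int64_t hidden_size,
                            int64_t block_size,
                            float* out_row) {           // hidden_size floats
  for (int64_t col = 0; col < hidden_size; ++col) {
    const uint8_t byte = packed_row[col / 2];
    // Assumed packing order: low nibble holds the even column, high nibble the odd one.
    const int quant = (col % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out_row[col] = static_cast<float>(quant - 8) * row_scales[col / block_size];
  }
}

The point of the sketch is the scale indexing: every group of block_size columns in a row shares one scale, so the scale tensor for an expert weight matrix of shape {rows, cols} has shape {rows, ceil(cols / block_size)}. That is the layout the updated CheckInputs in moe_helper.h validates below.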

File tree: 6 files changed, +1268 −227 lines


onnxruntime/contrib_ops/cpu/moe/moe_helper.h

Lines changed: 59 additions & 4 deletions
@@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters,
                     const Tensor* fc3_experts_bias,    // optional
                     const Tensor* fc3_experts_scales,  // required for qMoE; NULL for MOE
                     const int64_t pack_size,           // number of weights packed together (like 2 for uint4 packed to uint8)
-                    const bool is_fused_swiglu) {
+                    const bool is_fused_swiglu,
+                    const int64_t block_size = 0) {  // block size for block-wise quantization
   // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
   ASSERT_TENSOR_2D_OR_3D(input);
   ASSERT_TENSOR_3D(fc1_experts_weights);
@@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters,
   CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size);
   CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size);

-  CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
-  CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
-  CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
+  // Validate scale tensors: Handle both row-wise and block-wise quantization flexibly
+  // First, detect the actual quantization method from the tensor shapes
+  bool is_row_wise_quantization = true;
+  if (fc1_experts_scales != nullptr) {
+    const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
+    if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) {
+      is_row_wise_quantization = false;
+    }
+  }
+
+  if (block_size > 0 && !is_row_wise_quantization) {
+    // Block-wise quantization: 3D scale tensors
+    // For block-wise quantization, we calculate the number of blocks using ceiling division
+    // to handle cases where the dimension is not perfectly divisible by block_size
+    const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size;
+    const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size;
+    const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size;
+
+    CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row);
+    CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row);
+    CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row);
+  } else {
+    // Row-wise quantization: 2D scale tensors or 3D with last dimension = 1
+    // Handle both {num_experts, features} and {num_experts, features, 1} shapes
+    if (fc1_experts_scales != nullptr) {
+      const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
+      if (fc1_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
+      } else if (fc1_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1);
+      } else {
+        ORT_THROW("fc1_experts_scales must be 2D or 3D tensor");
+      }
+    }
+
+    if (fc2_experts_scales != nullptr) {
+      const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims();
+      if (fc2_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
+      } else if (fc2_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1);
+      } else {
+        ORT_THROW("fc2_experts_scales must be 2D or 3D tensor");
+      }
+    }
+
+    if (fc3_experts_scales != nullptr) {
+      const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims();
+      if (fc3_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
+      } else if (fc3_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1);
+      } else {
+        ORT_THROW("fc3_experts_scales must be 2D or 3D tensor");
+      }
+    }
+  }

   if (fc3_experts_weights == nullptr) {
     ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr);
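
To make the new shape checks concrete, here is a hedged helper sketch that mirrors the ceiling-division logic above; ExpectedScaleShape is a hypothetical name used only for illustration, not a function added by this PR.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the validation above: the expected shape of the scale
// tensor for an expert weight matrix with `rows` output features and `cols` input
// features per expert. block_size == 0 selects row-wise quantization.
std::vector<int64_t> ExpectedScaleShape(int64_t num_experts, int64_t rows, int64_t cols,
                                        int64_t block_size) {
  if (block_size > 0) {
    // Ceiling division, as in the diff, so a trailing partial block still gets a scale.
    const int64_t blocks_per_row = (cols + block_size - 1) / block_size;
    return {num_experts, rows, blocks_per_row};  // block-wise: 3D scales
  }
  return {num_experts, rows};  // row-wise: 2D scales ({num_experts, rows, 1} is also accepted)
}

For example, with hypothetical sizes hidden_size = 3072 and block_size = 128, fc1_experts_scales would be expected to have shape {num_experts, fc1_inter_size, 24}, since (3072 + 127) / 128 = 24; fc2 and fc3 follow the same pattern with their own row and column dimensions.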
