@@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters,
                   const Tensor* fc3_experts_bias,    // optional
                   const Tensor* fc3_experts_scales,  // required for qMoE; NULL for MOE
                   const int64_t pack_size,           // number of weights packed together (like 2 for uint4 packed to uint8)
-                  const bool is_fused_swiglu) {
+                  const bool is_fused_swiglu,
+                  const int64_t block_size = 0) {  // block size for block-wise quantization
  // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
  ASSERT_TENSOR_2D_OR_3D(input);
  ASSERT_TENSOR_3D(fc1_experts_weights);
@@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters,
  CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size);
  CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size);

-  CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
-  CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
-  CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
+  // Validate scale tensors: handle both row-wise and block-wise quantization.
+  // First, detect the actual quantization method from the tensor shapes.
+  bool is_row_wise_quantization = true;
+  if (fc1_experts_scales != nullptr) {
+    const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
+    if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) {
+      is_row_wise_quantization = false;
+    }
+  }
+
+  if (block_size > 0 && !is_row_wise_quantization) {
+    // Block-wise quantization: 3D scale tensors.
+    // The number of blocks per row is computed with ceiling division to handle
+    // dimensions that are not evenly divisible by block_size.
+    const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size;
+    const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size;
+    const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size;
+
+    CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row);
+    CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row);
+    CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row);
+  } else {
+    // Row-wise quantization: 2D scale tensors, or 3D with last dimension == 1.
+    // Handle both {num_experts, features} and {num_experts, features, 1} shapes.
+    if (fc1_experts_scales != nullptr) {
+      const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
+      if (fc1_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
+      } else if (fc1_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1);
+      } else {
+        ORT_THROW("fc1_experts_scales must be 2D or 3D tensor");
+      }
+    }
+
+    if (fc2_experts_scales != nullptr) {
+      const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims();
+      if (fc2_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
+      } else if (fc2_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1);
+      } else {
+        ORT_THROW("fc2_experts_scales must be 2D or 3D tensor");
+      }
+    }
+
+    if (fc3_experts_scales != nullptr) {
+      const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims();
+      if (fc3_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
+      } else if (fc3_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1);
+      } else {
+        ORT_THROW("fc3_experts_scales must be 2D or 3D tensor");
+      }
+    }
+  }

  if (fc3_experts_weights == nullptr) {
    ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr);
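As a sanity check on the shapes enforced above, here is a minimal standalone sketch (not part of the change) of what the ceiling-division rule implies for the scale tensors. All expert and feature sizes are made-up examples, and blocks_per_row simply mirrors the formula in the diff.

// Minimal sketch: expected qMoE scale shapes under the block-wise rule above.
// The sizes below are hypothetical examples, not values from this PR.
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t num_experts = 8;
  const int64_t hidden_size = 3072;  // input features of fc1/fc3, output features of fc2
  const int64_t inter_size = 800;    // input features of fc2 (deliberately not divisible by block_size)
  const int64_t block_size = 128;

  // blocks_per_row = ceil(features / block_size), written with integer arithmetic.
  const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size;  // 3072 / 128 = 24
  const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size;   // ceil(800 / 128) = 7
  assert(fc1_blocks_per_row == 24);
  assert(fc2_blocks_per_row == 7);

  // Block-wise quantization expects 3D scales, e.g. for fc2:
  // {num_experts, hidden_size, fc2_blocks_per_row}
  std::cout << "fc2 scales (block-wise): {" << num_experts << ", " << hidden_size
            << ", " << fc2_blocks_per_row << "}\n";  // {8, 3072, 7}

  // Row-wise quantization keeps one scale per output row:
  // {num_experts, hidden_size} or the equivalent {num_experts, hidden_size, 1}.
  std::cout << "fc2 scales (row-wise): {" << num_experts << ", " << hidden_size << "}\n";
  return 0;
}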