@@ -161,21 +161,31 @@ void get_cutlass_moe_mm_data_caller(
161
161
topk_ids.size (1 ));
162
162
}
163
163
164
// Fills the grouped-GEMM metadata for the pPLX MoE path, one expert per thread.
//
// Launch: a single block of `num_local_experts` threads (see the caller's
// `<<<1, num_local_experts, 0, stream>>>` launch); thread i handles expert i.
// Each expert's input slice is padded to `padded_m` rows, so expert i's data
// starts at offset i * padded_m.
//
// problem_sizes1/problem_sizes2 hold one (M, N, K) triple per expert:
//   GEMM1: tokens x (2*n) x k        GEMM2: tokens x k x n
// When SWAP_AB is true the M and N extents of each triple are exchanged
// (the caller selects this variant for small problems):
//   GEMM1: (2*n) x tokens x k        GEMM2: k x tokens x n
template <bool SWAP_AB>
__global__ void compute_pplx_data(int32_t* expert_offsets,
                                  int32_t* problem_sizes1,
                                  int32_t* problem_sizes2,
                                  const int32_t* __restrict__ expert_num_tokens,
                                  const int padded_m, const int n,
                                  const int k) {
  int expert_idx = threadIdx.x;
  // Hoist the per-expert token count: the original body loaded it from global
  // memory twice (once per problem_sizes array).
  const int32_t num_tokens = expert_num_tokens[expert_idx];

  expert_offsets[expert_idx] = expert_idx * padded_m;

  if constexpr (!SWAP_AB) {
    problem_sizes1[expert_idx * 3] = num_tokens;
    problem_sizes1[expert_idx * 3 + 1] = 2 * n;
    problem_sizes1[expert_idx * 3 + 2] = k;
    problem_sizes2[expert_idx * 3] = num_tokens;
    problem_sizes2[expert_idx * 3 + 1] = k;
    problem_sizes2[expert_idx * 3 + 2] = n;
  } else {
    // Swapped-AB layout: M and N are exchanged relative to the branch above.
    problem_sizes1[expert_idx * 3] = 2 * n;
    problem_sizes1[expert_idx * 3 + 1] = num_tokens;
    problem_sizes1[expert_idx * 3 + 2] = k;
    problem_sizes2[expert_idx * 3] = k;
    problem_sizes2[expert_idx * 3 + 1] = num_tokens;
    problem_sizes2[expert_idx * 3 + 2] = n;
  }
}
180
190
181
191
void get_cutlass_pplx_moe_mm_data_caller (torch::Tensor& expert_offsets,
@@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
187
197
const int64_t n, const int64_t k) {
188
198
auto stream = at::cuda::getCurrentCUDAStream (expert_offsets.device ().index ());
189
199
190
- compute_pplx_data<<<1 , num_local_experts, 0 , stream>>> (
191
- static_cast <int32_t *>(expert_offsets.data_ptr ()),
192
- static_cast <int32_t *>(problem_sizes1.data_ptr ()),
193
- static_cast <int32_t *>(problem_sizes2.data_ptr ()),
194
- static_cast <const int32_t *>(expert_num_tokens.data_ptr ()), padded_m, n,
195
- k);
200
+ if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
201
+ compute_pplx_data<false ><<<1 , num_local_experts, 0 , stream>>> (
202
+ static_cast <int32_t *>(expert_offsets.data_ptr ()),
203
+ static_cast <int32_t *>(problem_sizes1.data_ptr ()),
204
+ static_cast <int32_t *>(problem_sizes2.data_ptr ()),
205
+ static_cast <const int32_t *>(expert_num_tokens.data_ptr ()), padded_m, n,
206
+ k);
207
+ } else {
208
+ compute_pplx_data<true ><<<1 , num_local_experts, 0 , stream>>> (
209
+ static_cast <int32_t *>(expert_offsets.data_ptr ()),
210
+ static_cast <int32_t *>(problem_sizes1.data_ptr ()),
211
+ static_cast <int32_t *>(problem_sizes2.data_ptr ()),
212
+ static_cast <const int32_t *>(expert_num_tokens.data_ptr ()), padded_m, n,
213
+ k);
214
+ }
196
215
}
0 commit comments