Skip to content

Commit 244f50a

Browse files
shixianc authored and divakar-amd committed
[Fix] enable swap_ab for pplx problem size computation (vllm-project#22991)
Signed-off-by: Shixian Cui <[email protected]>
Co-authored-by: Shixian Cui <[email protected]>
1 parent 8f4c570 commit 244f50a

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 32 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -161,21 +161,31 @@ void get_cutlass_moe_mm_data_caller(
161161
topk_ids.size(1));
162162
}
163163

164+
template <bool SWAP_AB>
164165
__global__ void compute_pplx_data(int32_t* expert_offsets,
165166
int32_t* problem_sizes1,
166167
int32_t* problem_sizes2,
167168
const int32_t* __restrict__ expert_num_tokens,
168169
const int padded_m, const int n,
169170
const int k) {
170171
int expert_idx = threadIdx.x;
171-
172172
expert_offsets[expert_idx] = expert_idx * padded_m;
173-
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
174-
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
175-
problem_sizes1[expert_idx * 3 + 2] = k;
176-
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
177-
problem_sizes2[expert_idx * 3 + 1] = k;
178-
problem_sizes2[expert_idx * 3 + 2] = n;
173+
174+
if constexpr (!SWAP_AB) {
175+
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
176+
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
177+
problem_sizes1[expert_idx * 3 + 2] = k;
178+
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
179+
problem_sizes2[expert_idx * 3 + 1] = k;
180+
problem_sizes2[expert_idx * 3 + 2] = n;
181+
} else {
182+
problem_sizes1[expert_idx * 3] = 2 * n;
183+
problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
184+
problem_sizes1[expert_idx * 3 + 2] = k;
185+
problem_sizes2[expert_idx * 3] = k;
186+
problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
187+
problem_sizes2[expert_idx * 3 + 2] = n;
188+
}
179189
}
180190

181191
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
@@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
187197
const int64_t n, const int64_t k) {
188198
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
189199

190-
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
191-
static_cast<int32_t*>(expert_offsets.data_ptr()),
192-
static_cast<int32_t*>(problem_sizes1.data_ptr()),
193-
static_cast<int32_t*>(problem_sizes2.data_ptr()),
194-
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
195-
k);
200+
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
201+
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
202+
static_cast<int32_t*>(expert_offsets.data_ptr()),
203+
static_cast<int32_t*>(problem_sizes1.data_ptr()),
204+
static_cast<int32_t*>(problem_sizes2.data_ptr()),
205+
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
206+
k);
207+
} else {
208+
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
209+
static_cast<int32_t*>(expert_offsets.data_ptr()),
210+
static_cast<int32_t*>(problem_sizes1.data_ptr()),
211+
static_cast<int32_t*>(problem_sizes2.data_ptr()),
212+
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
213+
k);
214+
}
196215
}

0 commit comments

Comments
 (0)