Skip to content

Commit 6a5494b

Browse files
authored
Fix dynamic 0-size NaN complex problem (#73516)
1 parent e47977d commit 6a5494b

File tree

1 file changed

+15
-40
lines changed

1 file changed

+15
-40
lines changed

paddle/phi/kernels/gpu/moe_permute_kernel.cu

Lines changed: 15 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@ __global__ void tokens_unzip_stable_kernel(
5858
local_expert_offsets[i] = expert_base_offset.data[i];
5959
local_cumsum[i] = 0;
6060
}
61+
const int base_row_idx = blockIdx.x * CUMSUM_BLOCK_SIZE;
6162
__shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
6263
__shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
6364

@@ -238,13 +239,7 @@ void MoePermuteKernel(const Context &dev_ctx,
238239
"value.",
239240
MAX_NUM_EXPERTS,
240241
num_experts));
241-
if (X.numel() == 0) {
242-
dev_ctx.template Alloc<float>(XScale_unzipped);
243-
dev_ctx.template Alloc<int>(zipped_expertwise_rowmap);
244-
dev_ctx.template Alloc<T>(X_unzipped);
245-
dev_ctx.template Alloc<float>(token_prob_unzipped);
246-
return;
247-
}
242+
248243
const int quanted_cols = (XScale) ? XScale.get_ptr()->dims()[1] : 0;
249244
expert_base_offset expert_offset;
250245
int tokens_cumulated = 0;
@@ -270,46 +265,26 @@ void MoePermuteKernel(const Context &dev_ctx,
270265
dev_ctx.template Alloc<float>(XScale_unzipped);
271266
dev_ctx.template Alloc<int>(zipped_expertwise_rowmap);
272267
dev_ctx.template Alloc<T>(X_unzipped);
268+
dev_ctx.template Alloc<float>(token_prob_unzipped);
273269
auto X_unzipped_ptr = reinterpret_cast<void *>(X_unzipped->data<T>());
274-
for (int i = 0; i < num_experts; i++) {
275-
int next_expert_offset =
276-
i < num_experts - 1 ? expert_offset.data[i + 1] : output_rows;
277-
int invalid_rows =
278-
next_expert_offset - expert_offset.data[i] - tokens_per_expert[i];
279-
cudaMemsetAsync(X_unzipped_ptr + tokens_per_expert[i] * sizeof(T),
280-
0,
281-
sizeof(T) * invalid_rows * cols,
282-
dev_ctx.stream());
283-
}
270+
cudaMemsetAsync(
271+
X_unzipped_ptr, 0, sizeof(T) * output_rows * cols, dev_ctx.stream());
284272
if (XScale) {
285273
auto XScale_unzipped_ptr =
286274
reinterpret_cast<void *>(XScale_unzipped->data<float>());
287-
for (int i = 0; i < num_experts; i++) {
288-
int next_expert_offset =
289-
i < num_experts - 1 ? expert_offset.data[i + 1] : output_rows;
290-
int invalid_rows =
291-
next_expert_offset - expert_offset.data[i] - tokens_per_expert[i];
292-
cudaMemsetAsync(
293-
XScale_unzipped_ptr + tokens_per_expert[i] * sizeof(float),
294-
0,
295-
sizeof(float) * invalid_rows * quanted_cols,
296-
dev_ctx.stream());
297-
}
275+
cudaMemsetAsync(XScale_unzipped_ptr,
276+
0,
277+
sizeof(float) * output_rows * quanted_cols,
278+
dev_ctx.stream());
298279
}
299-
dev_ctx.template Alloc<float>(token_prob_unzipped);
280+
300281
auto token_prob_unzipped_ptr =
301282
reinterpret_cast<void *>(token_prob_unzipped->data<float>());
302-
for (int i = 0; i < num_experts; i++) {
303-
int next_expert_offset =
304-
i < num_experts - 1 ? expert_offset.data[i + 1] : output_rows;
305-
int invalid_rows =
306-
next_expert_offset - expert_offset.data[i] - tokens_per_expert[i];
307-
cudaMemsetAsync(
308-
token_prob_unzipped_ptr + tokens_per_expert[i] * sizeof(float),
309-
0,
310-
sizeof(float) * invalid_rows,
311-
dev_ctx.stream());
312-
}
283+
cudaMemsetAsync(token_prob_unzipped_ptr,
284+
0,
285+
sizeof(float) * output_rows,
286+
dev_ctx.stream());
287+
if (X.numel() == 0) return;
313288
const int cumsum_blocknum =
314289
(rows + CUMSUM_BLOCK_SIZE - 1) / CUMSUM_BLOCK_SIZE;
315290
DenseTensor global_expertwise_block_cumsum =

0 commit comments

Comments
 (0)