Commit 4c5a8f4

[None][fix] Rename: slot_count -> invalid_expert_id (#8783)
Signed-off-by: Bo Li <[email protected]>
1 parent 89e0117 commit 4c5a8f4

5 files changed, +10 -10 lines changed

cpp/tensorrt_llm/kernels/moePrepareKernels.cu

Lines changed: 4 additions & 4 deletions
@@ -280,7 +280,7 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
 }
 
 __global__ void memsetExpertIdsDevice(
-    int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, int rankCount)
+    int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int invalidExpertId, int rankCount)
 {
     int maxTokenCount = maxTokenCountPerRank * rankCount;
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -291,7 +291,7 @@ __global__ void memsetExpertIdsDevice(
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i + totalRecvTokenCount * topK < maxTokenCount * topK;
          i += gridDim.x * blockDim.x)
     {
-        *(expertIds + i + totalRecvTokenCount * topK) = slotCount;
+        *(expertIds + i + totalRecvTokenCount * topK) = invalidExpertId;
     }
 }
 
@@ -355,7 +355,7 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
         maxTokenCountPerRank);
 }
 
-void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
+void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int invalidExpertId,
     int rankCount, cudaStream_t stream)
 {
     int smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -364,7 +364,7 @@ void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPer
     dim3 grid(smCount);
 
     launchWithPdlWhenEnabled("memsetExpertIds", memsetExpertIdsDevice, grid, block, 0, stream, expertIds,
-        recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
+        recvCountsCumsum, maxTokenCountPerRank, topK, invalidExpertId, rankCount);
 }
 
 size_t getMoePrepareWorkspaceSize(int epSize)
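
Note (not part of the commit): the kernel above fills the unused tail of the expertIds buffer, i.e. everything past the tokens actually received, with the value now named invalidExpertId. A hedged PyTorch sketch of that effect follows; it assumes experts_ids holds max_token_count_per_rank * ep_size rows of top_k entries and that the total received token count is the last entry of the cumsum, neither of which is visible in this diff.

```python
# Hedged reference sketch of memsetExpertIdsDevice's effect; names and the
# cumsum layout are assumptions for illustration only.
import torch

def memset_expert_ids_reference(experts_ids: torch.Tensor,
                                recv_counts_cumsum: torch.Tensor,
                                max_token_count_per_rank: int,
                                top_k: int,
                                invalid_expert_id: int,
                                ep_size: int) -> None:
    # Assumed: the last cumsum entry is the total number of received tokens.
    total_recv_token_count = int(recv_counts_cumsum[ep_size - 1])
    # The buffer is sized for the worst case (max tokens per rank * ranks);
    # every slot past the real tokens gets the sentinel expert id.
    experts_ids.view(-1)[total_recv_token_count * top_k:] = invalid_expert_id
```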

cpp/tensorrt_llm/kernels/moePrepareKernels.h

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
     int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
     int maxTokenCountPerRank, cudaStream_t stream);
 
-void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
+void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int invalidExpertId,
     int epSize, cudaStream_t stream);
 
 size_t getMoePrepareWorkspaceSize(int epSize);

cpp/tensorrt_llm/thop/moeCommOp.cpp

Lines changed: 3 additions & 3 deletions
@@ -228,7 +228,7 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> expertsStati
 }
 
 void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum, int64_t maxTokenCountPerRank,
-    int64_t topK, int64_t slotCount, int64_t epSize)
+    int64_t topK, int64_t invalidExpertId, int64_t epSize)
 {
     CHECK_INPUT(expertsIds, torch::kInt32);
     TORCH_CHECK(expertsIds.dim() == 2, "expertsIds must be a 1D tensor");
@@ -243,7 +243,7 @@ void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum
     auto stream = at::cuda::getCurrentCUDAStream();
 
     tensorrt_llm::kernels::moe_prepare::memsetExpertIds(expertsIds.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
-        static_cast<int>(maxTokenCountPerRank), static_cast<int>(topK), static_cast<int>(slotCount),
+        static_cast<int>(maxTokenCountPerRank), static_cast<int>(topK), static_cast<int>(invalidExpertId),
         static_cast<int>(epSize), stream);
 }
 
@@ -310,7 +310,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "memset_expert_ids(Tensor(a!) experts_ids, Tensor recv_rank_count_cumsum, int max_token_count_per_rank, int "
         "top_k, "
-        "int slot_count, int ep_size) -> ()");
+        "int invalid_expert_id, int ep_size) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
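
Note (not part of the commit): the schema change above is what Python callers see. A hedged usage sketch follows, assuming the trtllm extension is loaded on a CUDA machine; the 2-D experts_ids layout and the cumsum contents are illustrative assumptions.

```python
# Hedged usage sketch of the registered op after the rename; shapes, values,
# and the CUDA device are assumptions, not taken from the commit.
import torch

ep_size = 4
max_token_count_per_rank = 8
top_k = 2
# Assumed layout: one row per (padded) token, top_k expert ids per row.
experts_ids = torch.empty(ep_size * max_token_count_per_rank, top_k,
                          dtype=torch.int32, device="cuda")
# Assumed: per-rank received-token counts, as a cumulative sum.
recv_rank_count_cumsum = torch.tensor([3, 5, 9, 12], dtype=torch.int32,
                                      device="cuda")
torch.ops.trtllm.memset_expert_ids(experts_ids, recv_rank_count_cumsum,
                                   max_token_count_per_rank, top_k,
                                   -1,  # invalid_expert_id sentinel
                                   ep_size)
```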

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@ def _(single_layer_load_balancer_ptr: int,
 
 @torch.library.register_fake("trtllm::memset_expert_ids")
 def _(experts_ids: torch.Tensor, recv_rank_count_cumsum: torch.Tensor,
-      max_token_count_per_rank: int, top_k: int, slot_count: int,
+      max_token_count_per_rank: int, top_k: int, invalid_expert_id: int,
       ep_size: int):
     pass
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ def forward_impl(
             alltoall_info.recv_rank_count_cumsum,
             max_num_token,
             top_k,
-            self.num_slots,
+            -1,  # Trtllm Gen uses -1 as invalid expert id
             self.ep_size,
         )
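
Note (not part of the commit): with the sentinel now passed explicitly (TRTLLM-Gen uses -1), code that needs to ignore the padded tail can compare against that sentinel. A small illustrative check with made-up data:

```python
# Hedged illustration only: padded entries written by memset_expert_ids can be
# masked by comparing against the sentinel value. The data here is made up.
import torch

experts_ids = torch.tensor([[3, 7], [0, 5], [-1, -1], [-1, -1]],
                           dtype=torch.int32)
valid = experts_ids != -1  # False exactly where the kernel wrote the sentinel
```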