
Commit 2ae080c

[https://nvbugs/5753788][chore] fix empty tensor cutlass moe
Signed-off-by: leslie-fang25 <[email protected]>
1 parent 910a633 commit 2ae080c

2 files changed (+29, -5 lines)


cpp/tensorrt_llm/kernels/quantization.cu

Lines changed: 8 additions & 2 deletions
@@ -146,7 +146,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
     int const numBlocksPerSM = std::max(1u, 2048u / block.x);
     // The number of blocks for m. The m dimension will be padded to 128 for swizzled layout.
     int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m;
-    dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM));
+    int gridSize = std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM);
+    // Ensure gridSize is not zero.
+    gridSize = std::max(1, gridSize);
+    dim3 grid(gridSize);
 
     // Launch the cvt kernel.
     auto* kernel_instance = useUE8M0
@@ -165,7 +168,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
     int const numBlocksPerSM = std::max(1u, 2048u / block.x);
     // The number of blocks for m. The m dimension will be padded to 128 for swizzled layout.
     int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m;
-    dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM));
+    int gridSize = std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM);
+    // Ensure gridSize is not zero.
+    gridSize = std::max(1, gridSize);
+    dim3 grid(gridSize);
 
     // Launch the cvt kernel.
     auto* kernel_instance = useUE8M0
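Why the clamp matters: with an empty input tensor, m is 0, so numBlocksForM and the resulting min(...) are both 0, and CUDA rejects a kernel launch whose grid dimension is zero. A minimal Python sketch of the grid arithmetic (the helper names and the example SM count are illustrative, not the actual launch code):

```python
def pad_up(x: int, y: int) -> int:
    # Mirrors PadUpFn: round x up to the next multiple of y.
    return (x + y - 1) // y * y

def grid_size(m: int, multi_processor_count: int, num_blocks_per_sm: int,
              swizzled: bool) -> int:
    num_blocks_for_m = pad_up(m, 128) if swizzled else m
    g = min(num_blocks_for_m, multi_processor_count * num_blocks_per_sm)
    # Pre-fix, m == 0 made g == 0, an invalid CUDA grid dimension;
    # clamping to 1 launches a single block that simply has no rows to process.
    return max(1, g)

assert grid_size(0, 132, 16, swizzled=True) == 1        # empty input: no zero-sized grid
assert grid_size(4096, 132, 16, swizzled=False) == 2112  # normal inputs are unaffected
```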

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 21 additions & 3 deletions
@@ -319,11 +319,17 @@ def quantize_input(
                 x_row = x.shape[0]
             else:
                 x_row = x.shape[0]
+                hidden_size = x.shape[-1]
                 x, x_sf = torch.ops.trtllm.fp4_quantize(
                     x, self.fc31_input_scale, self.scaling_vector_size,
                     False, False)
+                if x_sf.numel() == 0 and x_sf.dim() == 1:
+                    # Viewing torch.Size([0]) into (0, -1) is not supported
+                    x_sf = x_sf.view(
+                        (0,
+                         hidden_size // int(self.scaling_vector_size)))
             # Reshape x_sf to 2D for post-quant communication
-            if x_sf is not None:
+            if x_sf is not None and x_sf.numel() != 0:
                 x_sf = x_sf.view((x_row, -1))
         else:
             if not isinstance(x, Fp4QuantizedTensor):
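The explicit target shape is needed because view cannot infer a -1 dimension from a zero-element tensor. A small standalone repro of the failure and the workaround (hidden_size and scaling_vector_size are made-up values here):

```python
import torch

hidden_size, scaling_vector_size = 4096, 16   # illustrative values only
x_sf = torch.empty(0, dtype=torch.uint8)      # 1-D scale factors for zero tokens

try:
    x_sf.view((0, -1))                        # -1 is ambiguous with 0 elements
except RuntimeError as e:
    print(e)                                  # cannot reshape tensor of 0 elements ...

# Spelling out the second dimension succeeds, matching the fix above.
x_sf = x_sf.view((0, hidden_size // scaling_vector_size))
print(x_sf.shape)                             # torch.Size([0, 256])
```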
@@ -494,8 +500,20 @@ def forward_chunk(
         self._load_balancer_start_wait_gpu_stage(is_first_call)
 
         # apply routing
-        token_selected_experts, token_final_scales = self.routing_method.apply(
-            router_logits)
+        if router_logits.numel() == 0:
+            # For dtype, refer to https://github.com/NVIDIA/TensorRT-LLM/blob/55f3cda66d05a2e5686c9c7512721beb522bc8b7/tensorrt_llm/_torch/modules/fused_moe/routing.py#L327
+            token_selected_experts = torch.empty(
+                (0, self.routing_method.experts_per_token),
+                dtype=torch.int32,
+                device=router_logits.device)
+            token_final_scales = torch.empty(
+                (0, self.routing_method.experts_per_token),
+                dtype=torch.float32,
+                device=router_logits.device)
+        else:
+            token_selected_experts, token_final_scales = self.routing_method.apply(
+                router_logits)
+
         assert token_selected_experts.shape[
             1] == self.routing_method.experts_per_token
         assert token_selected_experts.shape == token_final_scales.shape
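The empty-logits branch hand-constructs outputs with the shapes and dtypes that routing_method.apply would otherwise produce, so the asserts that follow still hold. A minimal sketch with illustrative values for experts_per_token and the logits width:

```python
import torch

experts_per_token = 8                    # illustrative; normally from the routing method
router_logits = torch.empty((0, 128))    # zero tokens reached this chunk

if router_logits.numel() == 0:
    token_selected_experts = torch.empty((0, experts_per_token),
                                         dtype=torch.int32,
                                         device=router_logits.device)
    token_final_scales = torch.empty((0, experts_per_token),
                                     dtype=torch.float32,
                                     device=router_logits.device)

# The downstream shape asserts pass on the empty path.
assert token_selected_experts.shape[1] == experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
```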
