Skip to content

Commit 0cf9883

Browse files
committed
Fix torch compile in cutlass.
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
1 parent 8e7ab78 commit 0cf9883

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -715,25 +715,15 @@ def forward_fake(
715715
use_dp_padding: Optional[bool] = None,
716716
**kwargs,
717717
) -> Union[torch.Tensor, List[torch.Tensor]]:
718-
if not self.enable_alltoall:
719-
return super().forward_fake(
720-
x,
721-
router_logits,
722-
do_finalize=do_finalize,
723-
output_dtype=output_dtype,
724-
all_rank_num_tokens=all_rank_num_tokens,
725-
use_dp_padding=use_dp_padding,
726-
**kwargs,
727-
)
728-
else:
729-
is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
730-
data_type = output_dtype if is_nvfp4_input else x.dtype
731-
num_tokens = all_rank_num_tokens[
732-
self.mapping.tp_rank] if all_rank_num_tokens else x.shape[0]
733-
hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
734-
top_k = self.routing_method.experts_per_token
735-
return x.new_empty((num_tokens, top_k, hidden_size),
736-
dtype=data_type)
718+
return super().forward_fake(
719+
x,
720+
router_logits,
721+
do_finalize=do_finalize,
722+
output_dtype=output_dtype,
723+
all_rank_num_tokens=all_rank_num_tokens,
724+
use_dp_padding=use_dp_padding,
725+
**kwargs,
726+
)
737727

738728
def load_weights(self, weights: List[Dict]):
739729
assert self._weights_created

0 commit comments

Comments (0)