Skip to content

Commit c235f0d

Browse files
committed
fix compilation
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent e2ec49a commit c235f0d

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
255255
return output;
256256
}
257257

258-
torch::Tensor moe_output_memset(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
258+
void moe_output_memset_inplace(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
259259
torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& permuted_idx_to_expanded_idx,
260260
torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim, int64_t const top_k)
261261
{
@@ -305,8 +305,6 @@ torch::Tensor moe_output_memset(torch::Tensor const& input, torch::Tensor const&
305305
}
306306

307307
#undef DISPATCH_MOE_OUTPUT_MEMSET
308-
309-
return input;
310308
}
311309

312310
// Activation
@@ -478,8 +476,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
478476
"Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
479477
m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
480478
m.def(
481-
"moe_output_memset(Tensor! input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
482-
"Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> Tensor");
479+
"moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
480+
"Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> ()");
483481
m.def(
484482
"moe_swiglu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, "
485483
"int tile_tokens_dim) -> Tensor");
@@ -497,7 +495,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
497495
m.impl("moe_sort", &torch_ext::moe_sort);
498496
m.impl("moe_permute", &torch_ext::moe_permute);
499497
m.impl("moe_unpermute", &torch_ext::moe_unpermute);
500-
m.impl("moe_output_memset", &torch_ext::moe_output_memset);
498+
m.impl("moe_output_memset_inplace", &torch_ext::moe_output_memset_inplace);
501499
m.impl("moe_swiglu", &torch_ext::moe_swiglu);
502500
m.impl("moe_swiglu_nvfp4_quantize", &torch_ext::moe_swiglu_nvfp4_quantize);
503501
m.impl("moe_gelu", &torch_ext::moe_gelu);

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def inplace_info():
7676
},
7777
torch.ops.trtllm.logits_bitmask.default: {
7878
1: "logits"
79+
},
80+
torch.ops.trtllm.moe_output_memset_inplace.default: {
81+
1: "input"
82+
},
83+
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell.default:
84+
{
85+
6: "output"
7986
}
8087
}
8188
return inplace_map

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def run_moe_nvfp4(
302302
output = torch.empty(output_shape,
303303
dtype=output_dtype,
304304
device=x.device)
305-
torch.ops.trtllm.moe_output_memset(
305+
torch.ops.trtllm.moe_output_memset_inplace(
306306
output=output,
307307
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
308308
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,

tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def test_moe_unpermute(dtype: str, num_tokens: int, top_k: int, tile_size: int):
231231
@pytest.mark.parametrize("top_k", [1, 2, 8])
232232
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
233233
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
234-
def test_moe_output_memset(dtype: str, num_tokens: int, top_k: int, tile_size: int):
234+
def test_moe_output_memset_inplace(dtype: str, num_tokens: int, top_k: int, tile_size: int):
235235
dtype = getattr(torch, dtype)
236236
hidden_size = 4096
237237
num_experts = 256
@@ -260,7 +260,7 @@ def test_moe_output_memset(dtype: str, num_tokens: int, top_k: int, tile_size: i
260260
)
261261

262262
x = torch.ones(num_tokens, hidden_size, dtype=dtype, device="cuda")
263-
x = torch.ops.trtllm.moe_output_memset(
263+
torch.ops.trtllm.moe_output_memset_inplace(
264264
x,
265265
tile_idx_to_mn_limit,
266266
expanded_idx_to_permuted_idx,

0 commit comments

Comments
 (0)