@@ -255,7 +255,7 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
255255 return output;
256256}
257257
258- torch::Tensor moe_output_memset (torch::Tensor const & input, torch::Tensor const & tile_idx_to_mn_limit,
258+ void moe_output_memset_inplace (torch::Tensor const & input, torch::Tensor const & tile_idx_to_mn_limit,
259259 torch::Tensor const & expanded_idx_to_permuted_idx, torch::Tensor const & permuted_idx_to_expanded_idx,
260260 torch::Tensor const & num_non_exiting_tiles, int64_t const tile_tokens_dim, int64_t const top_k)
261261{
@@ -305,8 +305,6 @@ torch::Tensor moe_output_memset(torch::Tensor const& input, torch::Tensor const&
305305 }
306306
307307#undef DISPATCH_MOE_OUTPUT_MEMSET
308-
309- return input;
310308}
311309
312310// Activation
@@ -478,8 +476,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
478476 " Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)" );
479477 m.def (" moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor" );
480478 m.def (
481- " moe_output_memset (Tensor! input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
482- " Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> Tensor " );
479+ " moe_output_memset_inplace (Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
480+ " Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> () " );
483481 m.def (
484482 " moe_swiglu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, "
485483 " int tile_tokens_dim) -> Tensor" );
@@ -497,7 +495,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
497495 m.impl (" moe_sort" , &torch_ext::moe_sort);
498496 m.impl (" moe_permute" , &torch_ext::moe_permute);
499497 m.impl (" moe_unpermute" , &torch_ext::moe_unpermute);
500- m.impl (" moe_output_memset " , &torch_ext::moe_output_memset );
498+ m.impl (" moe_output_memset_inplace " , &torch_ext::moe_output_memset_inplace );
501499 m.impl (" moe_swiglu" , &torch_ext::moe_swiglu);
502500 m.impl (" moe_swiglu_nvfp4_quantize" , &torch_ext::moe_swiglu_nvfp4_quantize);
503501 m.impl (" moe_gelu" , &torch_ext::moe_gelu);
0 commit comments