Skip to content

Commit e07fff4

Browse files
authored
[https://nvbugs/5340941] - fix: Correct custom ops used by Qwen3 Moe … (NVIDIA#6285)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 6268a60 commit e07fff4

File tree

4 files changed

+14
-4
lines changed

4 files changed

+14
-4
lines changed

cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,8 @@ void fused_qk_norm_rope(
 75  75   TORCH_LIBRARY_FRAGMENT(trtllm, m)
 76  76   {
 77  77       m.def(
 78      -       "fused_qk_norm_rope(Tensor qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float eps, "
 79      -       "Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids) -> ()",
 80      -       &fused_qk_norm_rope);
     78  +       "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
     79  +       "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids) -> ()");
 81  80   }
 82  81

8382
// Register the CUDA implementation

cpp/tensorrt_llm/thop/renormMoeRoutingOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& route
 74  74   TORCH_LIBRARY_FRAGMENT(trtllm, m)
 75  75   {
 76  76       m.def(
 77      -       "renorm_moe_routing_op(Tensor router_logits, int topk"
     77  +       "renorm_moe_routing_op(Tensor router_logits, SymInt topk"
 78  78           ") -> (Tensor, Tensor)");
 79  79   }
 80  80

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def inplace_info():
 55  55       },
 56  56       torch.ops.trtllm.mla_custom_op_inplace.default: {
 57  57           1: "output"
     58  +    },
     59  +    torch.ops.trtllm.fused_qk_norm_rope.default: {
     60  +        1: "qkv"
 58  61       }
 59  62   }
 60  63   return inplace_map

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,11 @@ def _(input, residual, norm_weight, expanded_idx_to_permuted_idx,
 523  523           torch.empty_like(residual),
 524  524           torch.empty_like(residual),
 525  525       ]
      526  +
      527  +
      528  + @torch.library.register_fake("trtllm::renorm_moe_routing_op")
      529  + def _(router_logits, topk):
      530  +     num_tokens = router_logits.shape[0]
      531  +     sz = (num_tokens, topk)
      532  +     return router_logits.new_empty(
      533  +         sz, dtype=torch.int32), router_logits.new_empty(sz, dtype=torch.float32)

0 commit comments

Comments (0)