
Commit a1a1d79

refined IR and elaborate comments
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent f55aa75 commit a1a1d79

File tree

3 files changed: +148 -11 lines changed
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GLM4-MoE model patches for auto-deploy compatibility.

This module patches the GLM4-MoE model to make it compatible with torch.fx export
by replacing data-dependent operations (torch.where/nonzero) with traceable custom ops.
"""

import types
from typing import Dict

import torch
from transformers import AutoModelForCausalLM


@torch.inference_mode()
def glm4_moe_forward(self, hidden_states):
    """Glm4MoeMoE forward function rewritten to enable torch export.

    Replaces the self.moe() call (which uses torch.where) with the torch_moe custom op.
    """
    residuals = hidden_states
    orig_shape = hidden_states.shape

    # Gate directly returns (topk_indices, topk_weights)
    topk_indices, topk_weights = self.gate(hidden_states)

    # Flatten for MoE processing
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # Replace self.moe() with the torch_moe custom op.
    # self.experts is a ModuleList of Glm4MoeMLP, each with gate_proj, up_proj, down_proj.
    # Collect weights from each expert.
    w1_weight = [expert.gate_proj.weight for expert in self.experts]  # gate_proj
    w2_weight = [expert.down_proj.weight for expert in self.experts]  # down_proj
    w3_weight = [expert.up_proj.weight for expert in self.experts]  # up_proj

    hidden_states = torch.ops.auto_deploy.torch_moe(
        hidden_states,
        topk_indices,
        topk_weights,
        w1_weight=w1_weight,
        w2_weight=w2_weight,
        w3_weight=w3_weight,
    )

    hidden_states = hidden_states.view(*orig_shape)

    # Add shared experts output
    hidden_states = hidden_states + self.shared_experts(residuals)

    return hidden_states


# Store original from_config
_from_config_original = AutoModelForCausalLM.from_config

# Module patches mapping
CUSTOM_MODULE_PATCHES: Dict[str, callable] = {
    "Glm4MoeMoE": glm4_moe_forward,
}


def get_model_from_config_patched(config, **kwargs):
    """Patched from_config that applies GLM4-MoE module patches."""
    model = _from_config_original(config, **kwargs)

    # Patch modules
    for _, module in model.named_modules():
        if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys():
            module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)

    return model


# Apply the patch
AutoModelForCausalLM.from_config = get_model_from_config_patched
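
For orientation, here is a plain-PyTorch sketch of the computation that torch.ops.auto_deploy.torch_moe stands in for above, assuming the standard SwiGLU expert formulation implied by the gate_proj/up_proj/down_proj naming; the helper below is illustrative and is not the custom op's actual implementation.

import torch
import torch.nn.functional as F


def torch_moe_reference(hidden_states, topk_indices, topk_weights, w1_weight, w2_weight, w3_weight):
    """Illustrative dense-loop MoE with the same argument layout as the call above.

    hidden_states: [num_tokens, hidden_size]; topk_indices/topk_weights: [num_tokens, top_k].
    """
    out = torch.zeros_like(hidden_states)
    for expert_id, (w1, w2, w3) in enumerate(zip(w1_weight, w2_weight, w3_weight)):
        # Find which (token, slot) pairs routed to this expert
        token_ids, slot_ids = (topk_indices == expert_id).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        x = hidden_states[token_ids]
        # SwiGLU expert: down_proj(silu(gate_proj(x)) * up_proj(x))
        expert_out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
        scale = topk_weights[token_ids, slot_ids].unsqueeze(-1).to(expert_out.dtype)
        out.index_add_(0, token_ids, expert_out * scale)
    return out

Wrapping this routing in a custom op keeps the data-dependent indexing out of the traced graph, which is why the torch.where-based path in the stock self.moe() can be bypassed for export.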

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 12 additions & 2 deletions
@@ -1912,8 +1912,18 @@ def _shard_intermediate_attention_weights(
     Shard intermediate weights (e.g. q_norm, k_norm) for attention layers.
 
     For attention layers, there may be intermediate weights (like q_norm.weight, k_norm.weight)
-    that operate element-wise on the sharded output of q_proj/k_proj. These need to be sharded
-    along the same dimension.
+    that operate directly on the q_proj/k_proj output (before reshaping to [batch, seq, num_heads, head_dim]).
+    These need to be sharded along the same head dimension.
+
+    Example 1: norm over all heads, applied directly to the flattened Q/K output [batch, seq, hidden_size] (e.g. MiniMax):
+        self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps)
+        weight shape: [num_heads * head_dim]
+        Status: needs q/k norm sharding (handled in this function).
+
+    Example 2: norm per head, applied after reshaping to [batch, seq, num_heads, head_dim] (e.g. GLM 4.7):
+        self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        weight shape: [head_dim]
+        Status: no need to shard; will be skipped.
 
     Args:
         layer_subgraph: The attention layer subgraph
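
To make the two cases concrete, here is a small illustrative sketch of the sharding decision the docstring describes; the function name and the tp_rank/tp_size parameters are hypothetical and do not reflect the actual signature used in sharding.py.

import torch


def shard_qk_norm_weight(weight: torch.Tensor, num_heads: int, head_dim: int,
                         tp_rank: int, tp_size: int) -> torch.Tensor:
    """Illustrative only: shard a q_norm/k_norm weight along the head dimension.

    Case 1 (MiniMax-style): weight has shape [num_heads * head_dim]; keep the slice
    belonging to this rank's heads, matching the column-sharded q_proj/k_proj output.
    Case 2 (GLM-style): weight has shape [head_dim]; it is applied per head after
    reshaping, so every rank keeps the full weight.
    """
    if weight.numel() == head_dim:  # per-head norm -> no sharding needed
        return weight
    heads_per_rank = num_heads // tp_size
    start = tp_rank * heads_per_rank * head_dim
    end = start + heads_per_rank * head_dim
    return weight[start:end]  # this rank's [heads_per_rank * head_dim] slice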

tensorrt_llm/_torch/auto_deploy/utils/logger.py

Lines changed: 46 additions & 9 deletions
@@ -16,29 +16,66 @@ def _get_dtype_or_type(val):
     return type(val).__name__
 
 
+def _get_shape_str(val):
+    """Get shape as 'dim0xdim1x...' string, or '?' if not available."""
+    if hasattr(val, "shape"):
+        # Handle symbolic dimensions (SymInt) by converting to str
+        dims = [str(int(d)) if str(d).isdigit() else str(d) for d in val.shape]
+        return "x".join(dims) if dims else "scalar"
+    return "?"
+
+
+def _get_shape_dtype_str(val):
+    """Return 'shape : dtype' string for a value."""
+    shape = _get_shape_str(val)
+    dtype = _get_dtype_or_type(val)
+    return f"{shape} : {dtype}"
+
+
 def dump_ssa_with_meta(f, mod):
     for node in mod.graph.nodes:
         # Write out IR in traditional SSA style
         if node.op == "placeholder":
             if "val" in node.meta:
-                dtype = _get_dtype_or_type(node.meta["val"])
+                shape_dtype = _get_shape_dtype_str(node.meta["val"])
             else:
-                dtype = "unknown"
-            f.write(f"%{node.name} : {dtype}\n")
+                shape_dtype = "? : unknown"
+            f.write(f"%{node.name} : {shape_dtype}\n")
         elif node.op in ("call_function", "call_method", "call_module"):
-            # Build inputs list in SSA format
+            # Build inputs list in SSA format with shape:dtype info
            input_vars = []
             for arg in node.args:
                 if hasattr(arg, "name"):
-                    input_vars.append(f"%{arg.name}")
+                    # Look up the arg node's metadata for shape/dtype
+                    if hasattr(arg, "meta") and "val" in arg.meta:
+                        arg_shape_dtype = _get_shape_dtype_str(arg.meta["val"])
+                        input_vars.append(f"%{arg.name} : {arg_shape_dtype}")
+                    else:
+                        input_vars.append(f"%{arg.name} : ? : unknown")
                 else:
                     input_vars.append(str(arg))
+
+            # Handle output shape/dtype (including multi-output)
             if "val" in node.meta:
-                out_dtype = _get_dtype_or_type(node.meta["val"])
+                out_val = node.meta["val"]
+                if isinstance(out_val, (tuple, list)):
+                    # Multi-output: (shape1, shape2) : (dtype1, dtype2)
+                    shapes = []
+                    dtypes = []
+                    for v in out_val:
+                        if v is not None:
+                            shapes.append(_get_shape_str(v))
+                            dtypes.append(str(_get_dtype_or_type(v)))
+                        else:
+                            shapes.append("?")
+                            dtypes.append("None")
+                    out_info = f"({', '.join(shapes)}) : ({', '.join(dtypes)})"
+                else:
+                    out_info = _get_shape_dtype_str(out_val)
             else:
-                out_dtype = "N/A"
-            # Standard SSA notation: %out = op(args) : out_dtype
-            f.write(f"%{node.name} = {node.target}({', '.join(input_vars)}) : {out_dtype}\n")
+                out_info = "? : N/A"
+            # Standard SSA notation: %out = op(args) : shape : dtype
+            f.write(f"%{node.name} = {node.target}({', '.join(input_vars)}) : {out_info}\n")
         elif node.op == "output":
             # Output assignment in SSA IR
             outputs = node.args[0] if isinstance(node.args[0], (tuple, list)) else [node.args[0]]