Skip to content

Commit 8282fc4

Browse files
committed
lm_head weight tying fix
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent 2d96941 commit 8282fc4

File tree

5 files changed

+550
-0
lines changed

5 files changed

+550
-0
lines changed

.cursorignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
2+
3+
*cubin.cpp
4+
*cubin.h
5+
3rdparty/
6+
data/
7+
tensorrt_llm/_torch/models/

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ transforms:
100100
############################################################################################
101101
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
102102
############################################################################################
103+
sync_tied_weights:
104+
stage: post_load_fusion
105+
run_per_gm: false
103106
fuse_gemms:
104107
stage: post_load_fusion
105108
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
"""Transform to sync tied weights after submodule export and weight loading.
2+
3+
When a submodule is exported to a GraphModule, weight tying between parameters
4+
inside and outside the exported submodule can break. This transform restores
5+
the tying by making non-exported parameters reference the exported parameters'
6+
tensors.
7+
8+
This transform runs AFTER weights are loaded (stage: post_load_fusion) so it can
9+
directly sync the already-loaded weights.
10+
11+
This is particularly important for VLM models like Gemma3 where:
12+
- embed_tokens.weight is inside the exported language_model
13+
- lm_head.weight is outside (at parent level)
14+
- They share the same weight via _tied_weights_keys
15+
"""
16+
17+
from typing import List, Set, Tuple, Type
18+
19+
import torch
20+
import torch.nn as nn
21+
22+
from ...models.factory import ModelFactory
23+
from ...shim.interface import CachedSequenceInterface
24+
from ...utils.logger import ad_logger
25+
from ..interface import (
26+
BaseTransform,
27+
SharedConfig,
28+
TransformConfig,
29+
TransformInfo,
30+
TransformRegistry,
31+
)
32+
33+
34+
def _get_tied_weight_pairs(mod: nn.Module) -> List[Tuple[str, str]]:
35+
"""Extract tied weight pairs from model's _tied_weights_keys attribute.
36+
37+
HF models can declare tied weights in multiple formats:
38+
1. Dict format: {"lm_head.weight": "model.embed_tokens.weight"} - explicit dst->src mapping
39+
2. List format: ["lm_head.weight"] - just lists the tied key, src is from get_input_embeddings()
40+
41+
For list format, we use get_input_embeddings() and get_output_embeddings() to determine
42+
the actual tying relationship.
43+
44+
Args:
45+
mod: The model to extract tied weight pairs from.
46+
47+
Returns:
48+
List of (dst_key, src_key) tuples where dst is tied TO src.
49+
Returns empty list if no tied weights are declared.
50+
"""
51+
tied_keys = getattr(mod, "_tied_weights_keys", None)
52+
if not tied_keys:
53+
return []
54+
55+
# Dict format: explicit mapping {"dst": "src"}
56+
if isinstance(tied_keys, dict):
57+
return list(tied_keys.items())
58+
59+
# List/set format: this typically means word embeddings are tied
60+
# Check config.tie_word_embeddings (HF's standard flag) to confirm
61+
if isinstance(tied_keys, (list, tuple, set)):
62+
# Check if tie_word_embeddings is enabled (HF's standard config flag)
63+
config = getattr(mod, "config", None)
64+
tie_word_embeddings = getattr(config, "tie_word_embeddings", None)
65+
66+
# Also check text_config for VLM models
67+
if tie_word_embeddings is None and config is not None:
68+
text_config = getattr(config, "text_config", None)
69+
if text_config is not None:
70+
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
71+
72+
if not tie_word_embeddings:
73+
ad_logger.debug(
74+
f"_tied_weights_keys={tied_keys} but tie_word_embeddings is not True, skipping"
75+
)
76+
return []
77+
78+
# tie_word_embeddings=True and we have a list like ["lm_head.weight"]
79+
# Use HF's standard methods to find the actual tied modules
80+
input_embeddings = None
81+
output_embeddings = None
82+
input_embed_key = None
83+
output_embed_key = None
84+
85+
try:
86+
if hasattr(mod, "get_input_embeddings"):
87+
input_embeddings = mod.get_input_embeddings()
88+
if hasattr(mod, "get_output_embeddings"):
89+
output_embeddings = mod.get_output_embeddings()
90+
except Exception:
91+
pass
92+
if input_embeddings is None or output_embeddings is None:
93+
ad_logger.warning(
94+
f"tie_word_embeddings=True but get_input_embeddings/get_output_embeddings "
95+
f"returned None (input={input_embeddings}, output={output_embeddings})"
96+
)
97+
return []
98+
99+
# Find the parameter paths for input and output embeddings
100+
for name, submod in mod.named_modules():
101+
if submod is input_embeddings:
102+
input_embed_key = f"{name}.weight" if name else "weight"
103+
if submod is output_embeddings:
104+
output_embed_key = f"{name}.weight" if name else "weight"
105+
if input_embed_key and output_embed_key and input_embed_key != output_embed_key:
106+
# output (lm_head) is tied TO input (embed_tokens)
107+
ad_logger.debug(
108+
f"Inferred tied weight pair: {output_embed_key} -> {input_embed_key} "
109+
f"(tie_word_embeddings=True)"
110+
)
111+
return [(output_embed_key, input_embed_key)]
112+
113+
ad_logger.warning(
114+
f"tie_word_embeddings=True but could not find embedding paths: "
115+
f"input={input_embed_key}, output={output_embed_key}"
116+
)
117+
return []
118+
119+
return []
120+
121+
122+
def _get_exported_submodule_keys(mod: nn.Module) -> List[str]:
123+
"""Infer which submodules were exported by detecting GraphModules.
124+
125+
Args:
126+
mod: The root model to search for exported submodules.
127+
128+
Returns:
129+
List of submodule key paths that are GraphModules (i.e., were exported).
130+
"""
131+
exported_keys = []
132+
for name, submod in mod.named_modules():
133+
if isinstance(submod, torch.fx.GraphModule):
134+
exported_keys.append(name)
135+
return exported_keys
136+
137+
138+
def _detect_cross_boundary_tied_weights(
    mod: nn.Module,
    exported_submodule_keys: List[str],
) -> Tuple[List[Tuple[str, str]], Set[str]]:
    """Detect tied weights that cross the export boundary.

    Export of a submodule can break weight tying between parameters inside
    and outside the exported region. The exported parameter is treated as
    the canonical source of truth because it is embedded in the
    GraphModule's graph (via get_attr nodes) and cannot easily be replaced.

    Args:
        mod: The root model containing both exported and non-exported submodules.
        exported_submodule_keys: List of submodule key paths that were exported.

    Returns:
        Tuple of:
        - List of (dst_key, src_key) pairs whose tying crosses the boundary
        - Set of canonical keys (the exported side of each pair)
    """
    tied_pairs = _get_tied_weight_pairs(mod)
    if not tied_pairs:
        return [], set()

    def _inside_export(param_key: str) -> bool:
        # An empty key means the root itself is a GraphModule, so every
        # parameter counts as exported.
        return any(
            sub == "" or param_key.startswith(f"{sub}.")
            for sub in exported_submodule_keys
        )

    boundary_pairs: List[Tuple[str, str]] = []
    canonical: Set[str] = set()
    for dst_key, src_key in tied_pairs:
        src_in = _inside_export(src_key)
        dst_in = _inside_export(dst_key)
        if src_in == dst_in:
            # Both inside or both outside the export: no boundary issue.
            # Deduplication elsewhere already covers the both-inside case.
            continue
        boundary_pairs.append((dst_key, src_key))
        # Whichever side is exported becomes canonical.
        canonical.add(src_key if src_in else dst_key)

    return boundary_pairs, canonical
193+
194+
195+
def _sync_tied_weights(
196+
mod: nn.Module,
197+
cross_boundary_pairs: List[Tuple[str, str]],
198+
canonical_keys: Set[str],
199+
) -> int:
200+
"""Sync tied weights by making non-canonical weights point to canonical weights.
201+
202+
This function should be called AFTER weights are loaded. It makes the non-exported
203+
weight (e.g., lm_head.weight) point to the same tensor as the exported weight
204+
(e.g., embed_tokens.weight).
205+
206+
Args:
207+
mod: The root model with loaded weights.
208+
cross_boundary_pairs: List of (dst_key, src_key) pairs with cross-boundary tying.
209+
canonical_keys: Set of parameter keys that are canonical (exported).
210+
211+
Returns:
212+
Number of weights successfully synced.
213+
"""
214+
synced_count = 0
215+
for dst_key, src_key in cross_boundary_pairs:
216+
# Determine canonical vs redirect keys
217+
if src_key in canonical_keys:
218+
canonical_key = src_key
219+
redirect_key = dst_key
220+
else:
221+
canonical_key = dst_key
222+
redirect_key = src_key
223+
224+
try:
225+
# Get the loaded canonical parameter
226+
canonical_param = mod.get_parameter(canonical_key)
227+
228+
# Parse redirect key into module path and param name
229+
parts = redirect_key.rsplit(".", 1)
230+
if len(parts) > 1:
231+
redirect_mod = mod.get_submodule(parts[0])
232+
redirect_name = parts[1]
233+
else:
234+
redirect_mod = mod
235+
redirect_name = parts[0]
236+
237+
# Remove from _parameters so it's not a registered parameter
238+
# (prevents double-counting in state_dict, optimizer, etc.)
239+
if redirect_name in redirect_mod._parameters:
240+
del redirect_mod._parameters[redirect_name]
241+
242+
# Sync: make redirect point to the canonical tensor
243+
setattr(redirect_mod, redirect_name, canonical_param)
244+
ad_logger.info(f"Synced tied weight: {redirect_key} -> {canonical_key} (canonical)")
245+
synced_count += 1
246+
except Exception as e:
247+
ad_logger.warning(f"Failed to sync tied weight {redirect_key} -> {canonical_key}: {e}")
248+
249+
return synced_count
250+
251+
252+
class SyncTiedWeightsConfig(TransformConfig):
    """Configuration for the sync tied weights transform.

    Currently exposes no options beyond the base transform config.
    """
256+
257+
258+
@TransformRegistry.register("sync_tied_weights")
class SyncTiedWeights(BaseTransform):
    """Sync tied weights that cross the export boundary.

    Runs AFTER weights are loaded (stage: post_load_fusion). GraphModules
    found in the model mark the exported regions; any declared tied-weight
    pair straddling such a region is re-linked so the non-exported parameter
    references the exported (canonical) tensor.

    Example (Gemma3 VLM):
    - language_model is exported to a GraphModule (contains embed_tokens.weight)
    - lm_head lives at parent level (not exported)
    - _tied_weights_keys declares lm_head.weight -> embed_tokens.weight
    - this transform makes lm_head.weight reference embed_tokens.weight
    """

    config: SyncTiedWeightsConfig

    @classmethod
    def get_config_class(cls) -> Type[TransformConfig]:
        return SyncTiedWeightsConfig

    def _apply_to_full_model(
        self,
        mod: nn.Module,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[nn.Module, TransformInfo]:
        # Result reported whenever there is nothing to sync.
        no_op = TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)

        # GraphModules mark the exported submodules; without any, skip.
        exported = _get_exported_submodule_keys(mod)
        if not exported:
            return mod, no_op

        # Only tied-weight pairs that straddle an export boundary need fixing.
        pairs, canonical = _detect_cross_boundary_tied_weights(mod, exported)
        if not pairs:
            return mod, no_op

        # Weights are already loaded at this stage, so sync the tensors directly.
        num_synced = _sync_tied_weights(mod, pairs, canonical)
        return mod, TransformInfo(
            skipped=False,
            num_matches=num_synced,
            is_clean=True,
            has_valid_shapes=True,
        )

0 commit comments

Comments
 (0)