|
| 1 | +"""Transform to sync tied weights after submodule export and weight loading. |
| 2 | +
|
| 3 | +When a submodule is exported to a GraphModule, weight tying between parameters |
| 4 | +inside and outside the exported submodule can break. This transform restores |
| 5 | +the tying by making non-exported parameters reference the exported parameters' |
| 6 | +tensors. |
| 7 | +
|
| 8 | +This transform runs AFTER weights are loaded (stage: post_load_fusion) so it can |
| 9 | +directly sync the already-loaded weights. |
| 10 | +
|
| 11 | +This is particularly important for VLM models like Gemma3 where: |
| 12 | +- embed_tokens.weight is inside the exported language_model |
| 13 | +- lm_head.weight is outside (at parent level) |
| 14 | +- They share the same weight via _tied_weights_keys |
| 15 | +""" |
| 16 | + |
| 17 | +from typing import List, Set, Tuple, Type |
| 18 | + |
| 19 | +import torch |
| 20 | +import torch.nn as nn |
| 21 | + |
| 22 | +from ...models.factory import ModelFactory |
| 23 | +from ...shim.interface import CachedSequenceInterface |
| 24 | +from ...utils.logger import ad_logger |
| 25 | +from ..interface import ( |
| 26 | + BaseTransform, |
| 27 | + SharedConfig, |
| 28 | + TransformConfig, |
| 29 | + TransformInfo, |
| 30 | + TransformRegistry, |
| 31 | +) |
| 32 | + |
| 33 | + |
def _get_tied_weight_pairs(mod: nn.Module) -> List[Tuple[str, str]]:
    """Extract tied weight pairs from model's _tied_weights_keys attribute.

    HF models can declare tied weights in multiple formats:
    1. Dict format: {"lm_head.weight": "model.embed_tokens.weight"} - explicit dst->src mapping
    2. List format: ["lm_head.weight"] - just lists the tied key, src is from get_input_embeddings()

    For list format, we use get_input_embeddings() and get_output_embeddings() to determine
    the actual tying relationship.

    Args:
        mod: The model to extract tied weight pairs from.

    Returns:
        List of (dst_key, src_key) tuples where dst is tied TO src.
        Returns empty list if no tied weights are declared.
    """
    tied_keys = getattr(mod, "_tied_weights_keys", None)
    if not tied_keys:
        return []

    # Dict format: explicit mapping {"dst": "src"} — pass through as-is.
    if isinstance(tied_keys, dict):
        return list(tied_keys.items())

    # List/set format: this typically means word embeddings are tied.
    # Check config.tie_word_embeddings (HF's standard flag) to confirm.
    if isinstance(tied_keys, (list, tuple, set)):
        # Check if tie_word_embeddings is enabled (HF's standard config flag)
        config = getattr(mod, "config", None)
        tie_word_embeddings = getattr(config, "tie_word_embeddings", None)

        # Also check text_config for VLM models (flag may live on the nested text config)
        if tie_word_embeddings is None and config is not None:
            text_config = getattr(config, "text_config", None)
            if text_config is not None:
                tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)

        if not tie_word_embeddings:
            ad_logger.debug(
                f"_tied_weights_keys={tied_keys} but tie_word_embeddings is not True, skipping"
            )
            return []

        # tie_word_embeddings=True and we have a list like ["lm_head.weight"].
        # Use HF's standard methods to find the actual tied modules.
        input_embeddings = None
        output_embeddings = None
        input_embed_key = None
        output_embed_key = None

        try:
            if hasattr(mod, "get_input_embeddings"):
                input_embeddings = mod.get_input_embeddings()
            if hasattr(mod, "get_output_embeddings"):
                output_embeddings = mod.get_output_embeddings()
        except Exception as e:
            # Best-effort: some models raise (e.g. NotImplementedError) from these
            # accessors. Log instead of silently swallowing so the failure is
            # diagnosable; the None-check below handles the fallback.
            ad_logger.debug(f"Querying input/output embeddings failed: {e}")
        if input_embeddings is None or output_embeddings is None:
            ad_logger.warning(
                f"tie_word_embeddings=True but get_input_embeddings/get_output_embeddings "
                f"returned None (input={input_embeddings}, output={output_embeddings})"
            )
            return []

        # Find the parameter paths for input and output embeddings by identity.
        for name, submod in mod.named_modules():
            if submod is input_embeddings:
                input_embed_key = f"{name}.weight" if name else "weight"
            if submod is output_embeddings:
                output_embed_key = f"{name}.weight" if name else "weight"
            if input_embed_key and output_embed_key and input_embed_key != output_embed_key:
                # output (lm_head) is tied TO input (embed_tokens)
                ad_logger.debug(
                    f"Inferred tied weight pair: {output_embed_key} -> {input_embed_key} "
                    f"(tie_word_embeddings=True)"
                )
                return [(output_embed_key, input_embed_key)]

        ad_logger.warning(
            f"tie_word_embeddings=True but could not find embedding paths: "
            f"input={input_embed_key}, output={output_embed_key}"
        )
        return []

    return []
| 120 | + |
| 121 | + |
| 122 | +def _get_exported_submodule_keys(mod: nn.Module) -> List[str]: |
| 123 | + """Infer which submodules were exported by detecting GraphModules. |
| 124 | +
|
| 125 | + Args: |
| 126 | + mod: The root model to search for exported submodules. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + List of submodule key paths that are GraphModules (i.e., were exported). |
| 130 | + """ |
| 131 | + exported_keys = [] |
| 132 | + for name, submod in mod.named_modules(): |
| 133 | + if isinstance(submod, torch.fx.GraphModule): |
| 134 | + exported_keys.append(name) |
| 135 | + return exported_keys |
| 136 | + |
| 137 | + |
def _detect_cross_boundary_tied_weights(
    mod: nn.Module,
    exported_submodule_keys: List[str],
) -> Tuple[List[Tuple[str, str]], Set[str]]:
    """Detect tied weights that cross the export boundary.

    When a submodule is exported, weight tying between parameters inside and
    outside the exported submodule can break. This function identifies such
    cross-boundary pairs.

    The exported parameter becomes the canonical source of truth because it is
    embedded in the GraphModule's graph (via get_attr nodes) and cannot be
    easily changed.

    Args:
        mod: The root model containing both exported and non-exported submodules.
        exported_submodule_keys: List of submodule key paths that were exported.

    Returns:
        Tuple of:
            - List of (dst_key, src_key) pairs that have cross-boundary tying
            - Set of canonical keys (the exported ones that are sources of truth)
    """
    declared_pairs = _get_tied_weight_pairs(mod)
    if not declared_pairs:
        return [], set()

    def _inside_export(param_key: str) -> bool:
        # An empty prefix means the root itself is a GraphModule, so every
        # parameter is inside the export.
        return any(
            prefix == "" or param_key.startswith(f"{prefix}.")
            for prefix in exported_submodule_keys
        )

    crossing: List[Tuple[str, str]] = []
    canonical: Set[str] = set()
    for dst_key, src_key in declared_pairs:
        dst_inside = _inside_export(dst_key)
        src_inside = _inside_export(src_key)

        # Both on the same side of the boundary: nothing to restore here.
        # (Existing deduplication already handles the both-exported case.)
        if dst_inside == src_inside:
            continue

        crossing.append((dst_key, src_key))
        # Whichever side is exported is the canonical source of truth.
        canonical.add(src_key if src_inside else dst_key)

    return crossing, canonical
| 193 | + |
| 194 | + |
def _sync_tied_weights(
    mod: nn.Module,
    cross_boundary_pairs: List[Tuple[str, str]],
    canonical_keys: Set[str],
) -> int:
    """Sync tied weights by making non-canonical weights point to canonical weights.

    This function should be called AFTER weights are loaded. It makes the non-exported
    weight (e.g., lm_head.weight) point to the same tensor as the exported weight
    (e.g., embed_tokens.weight).

    Args:
        mod: The root model with loaded weights.
        cross_boundary_pairs: List of (dst_key, src_key) pairs with cross-boundary tying.
        canonical_keys: Set of parameter keys that are canonical (exported).

    Returns:
        Number of weights successfully synced.
    """
    synced_count = 0
    for dst_key, src_key in cross_boundary_pairs:
        # Determine which side is canonical (exported) vs the one to redirect
        if src_key in canonical_keys:
            canonical_key, redirect_key = src_key, dst_key
        else:
            canonical_key, redirect_key = dst_key, src_key

        try:
            # Get the loaded canonical parameter
            canonical_param = mod.get_parameter(canonical_key)

            # Parse redirect key into module path and param name
            parts = redirect_key.rsplit(".", 1)
            if len(parts) > 1:
                redirect_mod = mod.get_submodule(parts[0])
                redirect_name = parts[1]
            else:
                redirect_mod = mod
                redirect_name = parts[0]

            # Register the canonical parameter under the redirect name so both
            # keys reference ONE tensor (standard PyTorch weight tying).
            # NOTE: plain setattr() with an nn.Parameter would re-register it
            # anyway via nn.Module.__setattr__, so deleting it from _parameters
            # first (as a previous revision did) cannot keep it out of
            # state_dict; register_parameter states the intent directly.
            redirect_mod.register_parameter(redirect_name, canonical_param)
            ad_logger.info(f"Synced tied weight: {redirect_key} -> {canonical_key} (canonical)")
            synced_count += 1
        except Exception as e:
            # Best-effort per pair: a bad key should not abort the other syncs.
            ad_logger.warning(f"Failed to sync tied weight {redirect_key} -> {canonical_key}: {e}")

    return synced_count
| 250 | + |
| 251 | + |
class SyncTiedWeightsConfig(TransformConfig):
    """Configuration for the sync tied weights transform.

    Currently declares no options of its own; it exists so the transform has a
    dedicated config type that can grow fields later.
    """
| 256 | + |
| 257 | + |
@TransformRegistry.register("sync_tied_weights")
class SyncTiedWeights(BaseTransform):
    """Sync tied weights that cross the export boundary.

    This transform runs AFTER weights are loaded (stage: post_load_fusion).
    It detects GraphModules to infer which submodules were exported, then
    syncs any tied weights that cross the export boundary.

    For example, in Gemma3 VLM:
    - language_model is exported to GraphModule (contains embed_tokens.weight)
    - lm_head is at parent level (not exported)
    - _tied_weights_keys declares lm_head.weight -> embed_tokens.weight
    - This transform makes lm_head.weight reference embed_tokens.weight
    """

    config: SyncTiedWeightsConfig

    @classmethod
    def get_config_class(cls) -> Type[TransformConfig]:
        """Return the config class associated with this transform."""
        return SyncTiedWeightsConfig

    def _apply_to_full_model(
        self,
        mod: nn.Module,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[nn.Module, TransformInfo]:
        """Detect cross-boundary tied weights and sync them on the loaded model."""
        skip_info = TransformInfo(
            skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
        )

        # Exported submodules are exactly the GraphModules in the hierarchy;
        # without any export there is nothing to sync.
        exported = _get_exported_submodule_keys(mod)
        if not exported:
            return mod, skip_info

        # Find tied pairs where one side is inside an export and one is not.
        pairs, canonical = _detect_cross_boundary_tied_weights(mod, exported)
        if not pairs:
            return mod, skip_info

        # Weights are already loaded at this stage, so sync them directly.
        num_synced = _sync_tied_weights(mod, pairs, canonical)
        return mod, TransformInfo(
            skipped=False,
            num_matches=num_synced,
            is_clean=True,
            has_valid_shapes=True,
        )
0 commit comments