# Preserve original rope scaling type in export due to transformers library AutoConfig issue #452
**`modelopt/torch/export/model_config_utils.py`**

```python
# @@ -16,13 +16,17 @@
"""Common utils for the ModelConfig."""

import dataclasses
import json
import math
import warnings
from pathlib import Path
from types import UnionType
from typing import Union, get_args, get_origin

import numpy as np
import torch

from ..utils.model_path_utils import fetch_model_config, is_huggingface_model_id
from .model_config import (
    QUANTIZATION_FP8_PC_PT,
    QUANTIZATION_INT4_AWQ,
```
```python
# @@ -227,6 +231,76 @@ def model_config_from_dict(d: dict) -> ModelConfig:
    return _from_dict(config_type, d)


def restore_original_rope_scaling(config_data: dict, original_model_path: str) -> dict:
    """Restore original rope_scaling configuration if it was modified by transformers.

    Some VLM models like Qwen2.5-VL have their rope_scaling configuration modified
    by the transformers library during loading (e.g., from "mrope" to "default" with
    additional fields). This function restores the original configuration.

    Args:
        config_data: The model configuration dictionary to restore
        original_model_path: Path to the original model directory or HuggingFace Hub model ID
            (e.g., "microsoft/DialoGPT-medium" or "/path/to/local/model")

    Returns:
        The config_data dictionary with restored rope_scaling (modified in-place)

    Note:
        This function automatically detects whether original_model_path is a local filesystem
        path or a HuggingFace Hub model ID. For Hub model IDs, it will fetch the config.json
        directly from the Hub. Requires the huggingface_hub package for Hub model ID support.
    """
    try:
        raw_original_config = None

        # Check if original_model_path is a HuggingFace Hub model ID or local path
        if is_huggingface_model_id(original_model_path):
            # Try to fetch config from HuggingFace Hub
            raw_original_config = fetch_model_config(original_model_path)
        else:
            # Handle as local filesystem path
            original_config_file = Path(original_model_path) / "config.json"
            if original_config_file.exists():
                with open(original_config_file) as f:
                    raw_original_config = json.load(f)

        # If we successfully got the original config, proceed with restoration
        if raw_original_config is not None:
            # Check if rope_scaling was modified from mrope to default
            if (
                "rope_scaling" in raw_original_config
                and "rope_scaling" in config_data
                and raw_original_config["rope_scaling"].get("type") == "mrope"
                and config_data["rope_scaling"].get("type") == "default"
                and "rope_type" in config_data["rope_scaling"]
            ):
                print(f"Restoring original rope_scaling configuration from {original_model_path}")
                config_data["rope_scaling"] = raw_original_config["rope_scaling"]

                # Also restore rope_scaling in text_config if it exists
                if (
                    "text_config" in config_data
                    and "rope_scaling" in config_data["text_config"]
                    and config_data["text_config"]["rope_scaling"].get("type") == "default"
                ):
                    config_data["text_config"]["rope_scaling"] = raw_original_config["rope_scaling"]
        elif is_huggingface_model_id(original_model_path):
            # Log that we couldn't find the original config
            warnings.warn(
                f"Could not fetch original config from HuggingFace Hub: {original_model_path}"
            )
        else:
            # Only warn if the local path was expected to exist
            original_config_file = Path(original_model_path) / "config.json"
            if not original_config_file.exists():
                warnings.warn(f"Original config file not found: {original_config_file}")
    except Exception as e:
        warnings.warn(f"Could not restore original rope_scaling configuration: {e}")

    return config_data
```
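For concreteness, here is a minimal usage sketch of the new helper. It assumes this PR is installed so the import path below resolves, and the `rope_scaling` field values are illustrative, modeled on a Qwen2.5-VL-style config rather than taken from any real checkpoint:

```python
import json
import tempfile
from pathlib import Path

from modelopt.torch.export.model_config_utils import restore_original_rope_scaling

# config.json as it exists on disk for the original checkpoint (assumed values).
original = {"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}}

# The same section after transformers' AutoConfig normalizes it on load:
# "mrope" is rewritten to "default" and a "rope_type" key is added.
loaded = {
    "rope_scaling": {"type": "default", "rope_type": "default", "mrope_section": [16, 24, 24]}
}

with tempfile.TemporaryDirectory() as model_dir:
    # The temporary directory is treated as a local path by the helper.
    (Path(model_dir) / "config.json").write_text(json.dumps(original))
    restored = restore_original_rope_scaling(loaded, model_dir)
    assert restored["rope_scaling"]["type"] == "mrope"  # original type is back
```

The assertion passes because the helper only rewrites `rope_scaling` when the on-disk type is `"mrope"`, the loaded type is `"default"`, and the telltale `rope_type` key is present, which is exactly the mutation pattern the PR targets.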
         
**Review comment on lines +234 to +302**

🧩 Analysis chain

**Restore logic misses `text_config`-only rope_scaling; unify fetch and avoid `print`.**

If the original config carries `rope_scaling` only under `text_config`, the current logic never restores it: the nested `text_config` branch runs only after the top-level `rope_scaling` check has already matched. The hand-rolled Hub-versus-local branching also duplicates the fetch helper, and the `print` call is inconsistent with the `warnings.warn` logging used elsewhere in the function.

Apply this diff to make restoration robust and logging consistent:
```diff
@@
-def restore_original_rope_scaling(config_data: dict, original_model_path: str) -> dict:
+def restore_original_rope_scaling(config_data: dict, original_model_path: str) -> dict:
@@
-    try:
-        raw_original_config = None
-
-        # Check if original_model_path is a HuggingFace Hub model ID or local path
-        if is_huggingface_model_id(original_model_path):
-            # Try to fetch config from HuggingFace Hub
-            raw_original_config = fetch_model_config(original_model_path)
-        else:
-            # Handle as local filesystem path
-            original_config_file = Path(original_model_path) / "config.json"
-            if original_config_file.exists():
-                with open(original_config_file) as f:
-                    raw_original_config = json.load(f)
+    try:
+        # Always use unified fetcher (handles both Hub and local)
+        raw_original_config = fetch_model_config(original_model_path)
@@
-        if raw_original_config is not None:
-            # Check if rope_scaling was modified from mrope to default
-            if (
-                "rope_scaling" in raw_original_config
-                and "rope_scaling" in config_data
-                and raw_original_config["rope_scaling"].get("type") == "mrope"
-                and config_data["rope_scaling"].get("type") == "default"
-                and "rope_type" in config_data["rope_scaling"]
-            ):
-                print(f"Restoring original rope_scaling configuration from {original_model_path}")
-                config_data["rope_scaling"] = raw_original_config["rope_scaling"]
-
-                # Also restore rope_scaling in text_config if it exists
-                if (
-                    "text_config" in config_data
-                    and "rope_scaling" in config_data["text_config"]
-                    and config_data["text_config"]["rope_scaling"].get("type") == "default"
-                ):
-                    config_data["text_config"]["rope_scaling"] = raw_original_config["rope_scaling"]
-        elif is_huggingface_model_id(original_model_path):
-            # Log that we couldn't find the original config
-            warnings.warn(
-                f"Could not fetch original config from HuggingFace Hub: {original_model_path}"
-            )
-        else:
-            # Only warn if the local path was expected to exist
-            original_config_file = Path(original_model_path) / "config.json"
-            if not original_config_file.exists():
-                warnings.warn(f"Original config file not found: {original_config_file}")
+        if raw_original_config is not None:
+            # Original rope may live at top-level or under text_config
+            orig_rope = (
+                raw_original_config.get("rope_scaling")
+                or raw_original_config.get("text_config", {}).get("rope_scaling")
+            )
+            if isinstance(orig_rope, dict) and orig_rope.get("type") == "mrope":
+                def _should_restore(cur: dict | None) -> bool:
+                    return isinstance(cur, dict) and cur.get("type") == "default"
+
+                restored = False
+                # Top-level
+                if _should_restore(config_data.get("rope_scaling")):
+                    config_data["rope_scaling"] = orig_rope
+                    restored = True
+                # Nested text_config
+                if "text_config" in config_data and _should_restore(
+                    config_data["text_config"].get("rope_scaling")
+                ):
+                    config_data["text_config"]["rope_scaling"] = orig_rope
+                    restored = True
+                if restored:
+                    warnings.warn(
+                        f"Restored original rope_scaling (mrope) from {original_model_path}"
+                    )
+        else:
+            warnings.warn(f"Could not load original config for: {original_model_path}")
@@
-    except Exception as e:
+    except Exception as e:
         warnings.warn(f"Could not restore original rope_scaling configuration: {e}")
```

🏁 Script executed:

```bash
#!/bin/bash
# 1) Ensure there is only one definition of restore_original_rope_scaling
rg -nP --type=py '^\s*def\s+restore_original_rope_scaling\s*\('
# 2) Grep for call sites to confirm both top-level and text_config cases may occur
rg -n 'text_config"|'\''text_config'\''|\brope_scaling\b' --type=json -g 'config.json' -S -C1 || true
```

Length of output: 284

I need to verify whether `fetch_model_config` also handles local paths before recommending the unified call.

🏁 Script executed:

```bash
#!/bin/bash
# 1) Find the definition of fetch_model_config
rg -nP --type=py '^\s*def\s+fetch_model_config\s*\(' -A 15
# 2) Check if current code uses print() anywhere
rg -n 'print\s*\(' modelopt/torch/export/model_config_utils.py
# 3) Check the imports at the top of the file to understand available helpers
head -50 modelopt/torch/export/model_config_utils.py
```

Length of output: 2873

🏁 Script executed:

```bash
#!/bin/bash
# 1) Get full implementation of fetch_model_config
rg -nP --type=py '^\s*def\s+fetch_model_config\s*\(' -A 50 modelopt/torch/utils/model_path_utils.py | head -70
# 2) Extract the exact restoration logic from restore_original_rope_scaling
sed -n '250,280p' modelopt/torch/export/model_config_utils.py
```

Length of output: 3560

**Replace redundant fetch logic with a unified call; fix the logic gap so `text_config`-only rope_scaling is restored; replace `print` with a warning.**

The review comment is correct. Verification confirms:

- `fetch_model_config` handles both HuggingFace Hub model IDs and local paths, so the Hub-versus-local branching in `restore_original_rope_scaling` duplicates it.
- The nested `text_config` restoration runs only when the top-level `rope_scaling` check has already matched, so a config whose `rope_scaling` lives only under `text_config` is never restored.
- The `print()` call is inconsistent with the `warnings.warn` logging used in every other branch of the function.

The suggested refactor unifies config fetching, checks the top-level and `text_config` locations independently, and routes all logging through `warnings.warn`. Apply the provided diff to resolve all three issues.
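To make the gap concrete, here is a small self-contained sketch. The config shape is an assumption modeled on VLM checkpoints that nest the language model under `text_config`; it contrasts the guard used by the PR's code with the independent checks from the suggested refactor:

```python
# Assumed VLM-style config: rope_scaling exists only under text_config.
config_data = {
    "model_type": "some_vlm",
    "text_config": {"rope_scaling": {"type": "default", "rope_type": "default"}},
}

# PR's code: the nested text_config restore sits inside the branch guarded by a
# top-level "rope_scaling" key, so with no top-level key nothing is restored.
top_level_guard = "rope_scaling" in config_data
print(top_level_guard)  # False -> the whole restoration branch is skipped


# Suggested refactor: each location is checked on its own.
def _should_restore(cur):
    return isinstance(cur, dict) and cur.get("type") == "default"


nested_hit = "text_config" in config_data and _should_restore(
    config_data["text_config"].get("rope_scaling")
)
print(nested_hit)  # True -> the text_config rope_scaling gets restored
```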
The hunk's trailing context continues with the next function in the file:

```python
def pad_weights(weights, tp_size):
    """Returns the padded weights to tp_size."""
    assert len(weights.shape) > 1
```
          
            
          
**Comment:** Can we just make a copy of config.json and keep it the same as before?
**Reply:** I prefer the current solution. We can't keep it the same as before, because we need to add the quantization config into config.json.
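A rough sketch of the trade-off being discussed: the exported config has to start from the original file but still gain a quantization section, so a byte-for-byte copy is not enough. The `quantization_config` key and its placement here are illustrative assumptions, not the library's actual export schema:

```python
import json
from pathlib import Path


def write_export_config(original_dir: str, export_dir: str, quant_cfg: dict) -> None:
    # Start from the original config so fields like rope_scaling stay untouched...
    cfg = json.loads((Path(original_dir) / "config.json").read_text())
    # ...then merge in the quantization settings the exported checkpoint needs.
    # This merge step is what makes a plain file copy impossible.
    cfg["quantization_config"] = quant_cfg
    (Path(export_dir) / "config.json").write_text(json.dumps(cfg, indent=2))
```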