Commit 489e6db

preserve original rope scaling type in export due to transformers library AutoConfig issue
1 parent 8c6b915 commit 489e6db

File tree

3 files changed: +58 -0 lines changed

examples/llm_ptq/hf_ptq.py
modelopt/torch/export/model_config_utils.py
modelopt/torch/export/unified_export_hf.py
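Background: for checkpoints that use multimodal rotary embeddings ("mrope"), such as Qwen2.5-VL, the transformers AutoConfig rewrites rope_scaling at load time, so the in-memory config no longer matches the checkpoint's config.json, and exporting writes the rewritten form back out. A minimal sketch of the mismatch (the local checkpoint path is illustrative, and the exact normalization depends on the transformers version):

```python
import json
from pathlib import Path

from transformers import AutoConfig

ckpt_dir = "/models/Qwen2.5-VL-7B-Instruct"  # illustrative local checkpoint directory

# rope_scaling as the checkpoint ships it in config.json.
raw = json.loads((Path(ckpt_dir) / "config.json").read_text())
print(raw["rope_scaling"]["type"])  # "mrope"

# rope_scaling after AutoConfig has loaded (and normalized) it.
loaded = AutoConfig.from_pretrained(ckpt_dir).to_dict()
print(loaded["rope_scaling"]["type"])         # "default" on affected versions
print("rope_type" in loaded["rope_scaling"])  # True -- the marker this commit checks for
```

Downstream runtimes that key off "mrope" can misread the exported config, which is what the restoration below guards against.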

examples/llm_ptq/hf_ptq.py

Lines changed: 4 additions & 0 deletions

@@ -249,6 +249,8 @@ def main(args):
             use_seq_device_map=args.use_seq_device_map,
             attn_implementation=args.attn_implementation,
         )
+        # Store original model path for config restoration
+        model._original_model_path = args.pyt_ckpt_path
     else:
         assert args.qformat in QUANT_CFG_CHOICES, (
             f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
@@ -270,6 +272,8 @@ def main(args):
             args.pyt_ckpt_path,
             **model_kwargs,
         )
+        # Store original model path for config restoration
+        model._original_model_path = args.pyt_ckpt_path
         calibration_only = True
     model_is_already_quantized = is_quantized(model)

modelopt/torch/export/model_config_utils.py

Lines changed: 48 additions & 0 deletions

@@ -227,6 +227,54 @@ def model_config_from_dict(d: dict) -> ModelConfig:
     return _from_dict(config_type, d)
 
 
+def restore_original_rope_scaling(config_data: dict, original_model_path: str) -> dict:
+    """Restore original rope_scaling configuration if it was modified by transformers.
+
+    Some VLM models like Qwen2.5-VL have their rope_scaling configuration modified
+    by the transformers library during loading (e.g., from "mrope" to "default" with
+    additional fields). This function restores the original configuration.
+
+    Args:
+        config_data: The model configuration dictionary to restore
+        original_model_path: Path to the original model directory
+
+    Returns:
+        The config_data dictionary with restored rope_scaling (modified in-place)
+    """
+    import json
+    import warnings
+    from pathlib import Path
+
+    try:
+        original_config_file = Path(original_model_path) / "config.json"
+        if original_config_file.exists():
+            with open(original_config_file) as f:
+                raw_original_config = json.load(f)
+
+            # Check if rope_scaling was modified from mrope to default
+            if (
+                "rope_scaling" in raw_original_config
+                and "rope_scaling" in config_data
+                and raw_original_config["rope_scaling"].get("type") == "mrope"
+                and config_data["rope_scaling"].get("type") == "default"
+                and "rope_type" in config_data["rope_scaling"]
+            ):
+                print(f"Restoring original rope_scaling configuration from {original_model_path}")
+                config_data["rope_scaling"] = raw_original_config["rope_scaling"]
+
+                # Also restore rope_scaling in text_config if it exists
+                if (
+                    "text_config" in config_data
+                    and "rope_scaling" in config_data["text_config"]
+                    and config_data["text_config"]["rope_scaling"].get("type") == "default"
+                ):
+                    config_data["text_config"]["rope_scaling"] = raw_original_config["rope_scaling"]
+    except Exception as e:
+        warnings.warn(f"Could not restore original rope_scaling configuration: {e}")
+
+    return config_data
+
+
 def pad_weights(weights, tp_size):
     """Returns the padded weights to tp_size."""
     assert len(weights.shape) > 1
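A self-contained sketch of the new helper in action, using a fabricated checkpoint directory and illustrative mrope_section values rather than a real model:

```python
import json
import tempfile
from pathlib import Path

from modelopt.torch.export.model_config_utils import restore_original_rope_scaling

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Fabricate the original checkpoint config with the mrope scaling type.
    (Path(ckpt_dir) / "config.json").write_text(
        json.dumps({"rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]}})
    )

    # What transformers' normalization leaves behind: the type rewritten to
    # "default" and an extra "rope_type" field added.
    exported = {
        "rope_scaling": {"type": "default", "rope_type": "default", "mrope_section": [16, 24, 24]}
    }

    restored = restore_original_rope_scaling(exported, ckpt_dir)
    assert restored["rope_scaling"]["type"] == "mrope"  # original type is back
```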

modelopt/torch/export/unified_export_hf.py

Lines changed: 6 additions & 0 deletions

@@ -54,6 +54,7 @@
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
+from .model_config_utils import restore_original_rope_scaling
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
@@ -541,6 +542,11 @@ def export_hf_checkpoint(
     with open(original_config) as file:
         config_data = json.load(file)
 
+    # Preserve original rope_scaling configuration if it was modified by transformers
+    original_model_path = getattr(model, "_original_model_path", None)
+    if original_model_path is not None:
+        config_data = restore_original_rope_scaling(config_data, original_model_path)
+
     config_data["quantization_config"] = hf_quant_config
 
     with open(original_config, "w") as file:
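End to end, the opt-in flow mirrors what hf_ptq.py now does. A rough sketch, assuming the export_dir keyword of export_hf_checkpoint and using an illustrative checkpoint path and auto class:

```python
from transformers import AutoModelForCausalLM

from modelopt.torch.export import export_hf_checkpoint

ckpt_path = "/models/Qwen2.5-VL-7B-Instruct"  # illustrative; pick the right auto class for VLMs
model = AutoModelForCausalLM.from_pretrained(ckpt_path)

# Opt in to rope_scaling restoration. Without this attribute,
# getattr(model, "_original_model_path", None) returns None inside
# export_hf_checkpoint and the exported config is left as transformers wrote it.
model._original_model_path = ckpt_path

# ... quantize / calibrate with modelopt here ...

export_hf_checkpoint(model, export_dir="exported-checkpoint")
```

Stashing the path as a private attribute keeps the change backward compatible: models loaded without it export exactly as before.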
