Commit 0841a0e

Edwardf0t1 authored and yeyu-nvidia committed
[Bugfix] Enable essential python files export in llm_ptq examples by adding a customized copy function (#351)
Signed-off-by: Zhiyu Cheng <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
1 parent 2d266d3 commit 0841a0e

File tree

examples/llm_ptq/example_utils.py
examples/llm_ptq/hf_ptq.py

2 files changed: +160 −1 lines changed

examples/llm_ptq/example_utils.py

Lines changed: 146 additions & 0 deletions
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import glob
 import os
+import shutil
 import sys
 import warnings
+from pathlib import Path
 from typing import Any
 
 import torch
@@ -24,6 +27,11 @@
 from accelerate.utils import get_max_memory
 from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
 
+try:
+    from huggingface_hub import snapshot_download
+except ImportError:
+    snapshot_download = None
+
 from modelopt.torch.utils.image_processor import MllamaImageProcessor
 
 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
@@ -253,3 +261,141 @@ def apply_kv_cache_quant(quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str
     quant_cfg["algorithm"] = "max"
 
     return quant_cfg
+
+
+def _resolve_model_path(model_name_or_path: str, trust_remote_code: bool = False) -> str:
+    """Resolve a model name or path to a local directory path.
+
+    If the input is already a local directory, returns it as-is.
+    If the input is a HuggingFace model ID, attempts to resolve it to the local cache path.
+
+    Args:
+        model_name_or_path: Either a local directory path or HuggingFace model ID
+        trust_remote_code: Whether to trust remote code when loading the model
+
+    Returns:
+        Local directory path to the model files
+    """
+    # If it's already a local directory, return as-is
+    if os.path.isdir(model_name_or_path):
+        return model_name_or_path
+
+    # Try to resolve HuggingFace model ID to local cache path
+    try:
+        # First try to load the config to trigger caching
+        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
+
+        # The config object should have the local path information
+        # Try different ways to get the cached path
+        if hasattr(config, "_name_or_path") and os.path.isdir(config._name_or_path):
+            return config._name_or_path
+
+        # Alternative: use snapshot_download if available
+        if snapshot_download is not None:
+            try:
+                local_path = snapshot_download(
+                    repo_id=model_name_or_path,
+                    allow_patterns=["*.py", "*.json"],  # Only download Python files and config
+                )
+                return local_path
+            except Exception as e:
+                print(f"Warning: Could not download model files using snapshot_download: {e}")
+
+        # Fallback: try to find in HuggingFace cache
+        from transformers.utils import TRANSFORMERS_CACHE
+
+        # Look for the model in the cache directory
+        cache_pattern = os.path.join(TRANSFORMERS_CACHE, "models--*")
+        cache_dirs = glob.glob(cache_pattern)
+
+        # Convert model name to cache directory format
+        model_cache_name = model_name_or_path.replace("/", "--")
+        for cache_dir in cache_dirs:
+            if model_cache_name in cache_dir:
+                # Look for the snapshots directory
+                snapshots_dir = os.path.join(cache_dir, "snapshots")
+                if os.path.exists(snapshots_dir):
+                    # Get the latest snapshot
+                    snapshot_dirs = [
+                        d
+                        for d in os.listdir(snapshots_dir)
+                        if os.path.isdir(os.path.join(snapshots_dir, d))
+                    ]
+                    if snapshot_dirs:
+                        latest_snapshot = max(snapshot_dirs)  # Use lexicographically latest
+                        snapshot_path = os.path.join(snapshots_dir, latest_snapshot)
+                        return snapshot_path
+
+    except Exception as e:
+        print(f"Warning: Could not resolve model path for {model_name_or_path}: {e}")
+
+    # If all else fails, return the original path
+    # This will cause the copy function to skip with a warning
+    return model_name_or_path
+
+
+def copy_custom_model_files(source_path: str, export_path: str, trust_remote_code: bool = False):
+    """Copy custom model files (configuration_*.py, modeling_*.py, *.json, etc.) from source to export directory.
+
+    This function copies custom Python files and JSON configuration files that are needed for
+    models with custom code. It excludes config.json and model.safetensors.index.json as these
+    are typically handled separately by the model export process.
+
+    Args:
+        source_path: Path to the original model directory or HuggingFace model ID
+        export_path: Path to the exported model directory
+        trust_remote_code: Whether trust_remote_code was used (only copy files if True)
+    """
+    if not trust_remote_code:
+        return
+
+    # Resolve the source path (handles both local paths and HF model IDs)
+    resolved_source_path = _resolve_model_path(source_path, trust_remote_code)
+
+    source_dir = Path(resolved_source_path)
+    export_dir = Path(export_path)
+
+    if not source_dir.exists():
+        if resolved_source_path != source_path:
+            print(
+                f"Warning: Could not find local cache for HuggingFace model '{source_path}' "
+                f"(resolved to '{resolved_source_path}')"
+            )
+        else:
+            print(f"Warning: Source directory '{source_path}' does not exist")
+        return
+
+    if not export_dir.exists():
+        print(f"Warning: Export directory {export_path} does not exist")
+        return
+
+    # Common patterns for custom model files that need to be copied
+    custom_file_patterns = [
+        "configuration_*.py",
+        "modeling_*.py",
+        "tokenization_*.py",
+        "processing_*.py",
+        "image_processing_*.py",
+        "feature_extraction_*.py",
+        "*.json",
+    ]
+
+    copied_files = []
+    for pattern in custom_file_patterns:
+        for file_path in source_dir.glob(pattern):
+            if file_path.is_file():
+                # Skip config.json and model.safetensors.index.json as they're handled separately
+                if file_path.name in ["config.json", "model.safetensors.index.json"]:
+                    continue
+                dest_path = export_dir / file_path.name
+                try:
+                    shutil.copy2(file_path, dest_path)
+                    copied_files.append(file_path.name)
+                    print(f"Copied custom model file: {file_path.name}")
+                except Exception as e:
+                    print(f"Warning: Failed to copy {file_path.name}: {e}")
+
+    if copied_files:
+        print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
+    else:
+        print("No custom model files found to copy")
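To make the new helper's behavior concrete, here is a minimal usage sketch (not part of the commit). It assumes it is run from examples/llm_ptq so that example_utils is importable; the file names are hypothetical:

import tempfile
from pathlib import Path

from example_utils import copy_custom_model_files

with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst:
    # Fake a checkpoint dir with one custom-code file and one file the helper skips.
    (Path(src) / "modeling_foo.py").write_text("# custom model code\n")
    (Path(src) / "config.json").write_text("{}")  # excluded: handled by the exporter

    copy_custom_model_files(src, dst, trust_remote_code=False)  # no-op unless True
    copy_custom_model_files(src, dst, trust_remote_code=True)   # copies modeling_foo.py

    print(sorted(p.name for p in Path(dst).iterdir()))  # ['modeling_foo.py']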

examples/llm_ptq/hf_ptq.py

Lines changed: 14 additions & 1 deletion
@@ -23,7 +23,14 @@
 import numpy as np
 import torch
 from accelerate.hooks import remove_hook_from_module
-from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
+from example_utils import (
+    apply_kv_cache_quant,
+    copy_custom_model_files,
+    get_model,
+    get_processor,
+    get_tokenizer,
+    is_enc_dec,
+)
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -604,6 +611,9 @@ def output_decode(generated_ids, input_shape):
             inference_tensor_parallel=args.inference_tensor_parallel,
             inference_pipeline_parallel=args.inference_pipeline_parallel,
         )
+
+        # Copy custom model files (Python files and JSON configs) for TensorRT-LLM export
+        copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
     else:
         # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
         assert args.sparsity_fmt == "dense", (
@@ -621,6 +631,9 @@ def output_decode(generated_ids, input_shape):
             export_dir=export_path,
         )
 
+        # Copy custom model files (Python files and JSON configs) if trust_remote_code is used
+        copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
+
         # Restore default padding and export the tokenizer as well.
         if tokenizer is not None:
             tokenizer.padding_side = default_padding_side
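For context, the cache fallback in _resolve_model_path leans on the huggingface_hub on-disk layout: a repo ID org/name is cached under models--org--name, with one directory per revision under snapshots/. A small illustrative sketch of that convention (the cache root shown is the usual default, overridable via HF_HOME / TRANSFORMERS_CACHE; the model ID is hypothetical):

import os

cache_root = os.path.expanduser("~/.cache/huggingface/hub")  # typical default location
model_id = "org/name"  # hypothetical repo ID

# Repo ID maps to a directory name by replacing "/" with "--", as the fallback does.
repo_dir = os.path.join(cache_root, "models--" + model_id.replace("/", "--"))
snapshots_dir = os.path.join(repo_dir, "snapshots")
print(snapshots_dir)  # each subdirectory here is one cached revision, named by commit hash

Note that because snapshot directories are named by commit hash, max() over them picks the lexicographically largest hash rather than the newest revision; the commit treats this only as a best-effort fallback when snapshot_download and the config path are unavailable.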
