-
Notifications
You must be signed in to change notification settings - Fork 169
[Bugfix] Enable essential python files export in llm_ptq examples by adding a customized copy function #351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
39ccfad
3b0f98a
4aa4285
9d305fa
21fa47f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,9 +13,12 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
import glob | ||||||||||||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||||||||||||
import shutil | ||||||||||||||||||||||||||||||||||||||||||||||||||
import sys | ||||||||||||||||||||||||||||||||||||||||||||||||||
import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -24,6 +27,11 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
from accelerate.utils import get_max_memory | ||||||||||||||||||||||||||||||||||||||||||||||||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||
from huggingface_hub import snapshot_download | ||||||||||||||||||||||||||||||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||||||||||||||||||||||||||||||
snapshot_download = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.utils.image_processor import MllamaImageProcessor | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -263,3 +271,133 @@ def apply_kv_cache_quant(quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str | |||||||||||||||||||||||||||||||||||||||||||||||||
quant_cfg["algorithm"] = "max" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
return quant_cfg | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
def _resolve_model_path(model_name_or_path: str, trust_remote_code: bool = False) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||
"""Resolve a model name or path to a local directory path. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
If the input is already a local directory, returns it as-is. | ||||||||||||||||||||||||||||||||||||||||||||||||||
If the input is a HuggingFace model ID, attempts to resolve it to the local cache path. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
model_name_or_path: Either a local directory path or HuggingFace model ID | ||||||||||||||||||||||||||||||||||||||||||||||||||
trust_remote_code: Whether to trust remote code when loading the model | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
Local directory path to the model files | ||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
# If it's already a local directory, return as-is | ||||||||||||||||||||||||||||||||||||||||||||||||||
if os.path.isdir(model_name_or_path): | ||||||||||||||||||||||||||||||||||||||||||||||||||
return model_name_or_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Try to resolve HuggingFace model ID to local cache path | ||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||
# First try to load the config to trigger caching | ||||||||||||||||||||||||||||||||||||||||||||||||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# The config object should have the local path information | ||||||||||||||||||||||||||||||||||||||||||||||||||
# Try different ways to get the cached path | ||||||||||||||||||||||||||||||||||||||||||||||||||
if hasattr(config, "_name_or_path") and os.path.isdir(config._name_or_path): | ||||||||||||||||||||||||||||||||||||||||||||||||||
return config._name_or_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Alternative: use snapshot_download if available | ||||||||||||||||||||||||||||||||||||||||||||||||||
if snapshot_download is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||
local_path = snapshot_download( | ||||||||||||||||||||||||||||||||||||||||||||||||||
repo_id=model_name_or_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||
allow_patterns=["*.py", "*.json"], # Only download Python files and config | ||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||
return local_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Warning: Could not download model files using snapshot_download: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Fallback: try to find in HuggingFace cache | ||||||||||||||||||||||||||||||||||||||||||||||||||
from transformers.utils import TRANSFORMERS_CACHE | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Look for the model in the cache directory | ||||||||||||||||||||||||||||||||||||||||||||||||||
cache_pattern = os.path.join(TRANSFORMERS_CACHE, "models--*") | ||||||||||||||||||||||||||||||||||||||||||||||||||
cache_dirs = glob.glob(cache_pattern) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Convert model name to cache directory format | ||||||||||||||||||||||||||||||||||||||||||||||||||
model_cache_name = model_name_or_path.replace("/", "--") | ||||||||||||||||||||||||||||||||||||||||||||||||||
for cache_dir in cache_dirs: | ||||||||||||||||||||||||||||||||||||||||||||||||||
if model_cache_name in cache_dir: | ||||||||||||||||||||||||||||||||||||||||||||||||||
# Look for the snapshots directory | ||||||||||||||||||||||||||||||||||||||||||||||||||
snapshots_dir = os.path.join(cache_dir, "snapshots") | ||||||||||||||||||||||||||||||||||||||||||||||||||
if os.path.exists(snapshots_dir): | ||||||||||||||||||||||||||||||||||||||||||||||||||
# Get the latest snapshot | ||||||||||||||||||||||||||||||||||||||||||||||||||
snapshot_dirs = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||
d | ||||||||||||||||||||||||||||||||||||||||||||||||||
for d in os.listdir(snapshots_dir) | ||||||||||||||||||||||||||||||||||||||||||||||||||
if os.path.isdir(os.path.join(snapshots_dir, d)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||
if snapshot_dirs: | ||||||||||||||||||||||||||||||||||||||||||||||||||
latest_snapshot = max(snapshot_dirs) # Use lexicographically latest | ||||||||||||||||||||||||||||||||||||||||||||||||||
snapshot_path = os.path.join(snapshots_dir, latest_snapshot) | ||||||||||||||||||||||||||||||||||||||||||||||||||
return snapshot_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Warning: Could not resolve model path for {model_name_or_path}: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# If all else fails, return the original path | ||||||||||||||||||||||||||||||||||||||||||||||||||
# This will cause the copy function to skip with a warning | ||||||||||||||||||||||||||||||||||||||||||||||||||
return model_name_or_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
def copy_custom_model_files(source_path: str, export_path: str, trust_remote_code: bool = False): | ||||||||||||||||||||||||||||||||||||||||||||||||||
"""Copy custom model files (configuration_*.py, modeling_*.py, etc.) from source to export directory. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
source_path: Path to the original model directory or HuggingFace model ID | ||||||||||||||||||||||||||||||||||||||||||||||||||
export_path: Path to the exported model directory | ||||||||||||||||||||||||||||||||||||||||||||||||||
trust_remote_code: Whether trust_remote_code was used (only copy files if True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
if not trust_remote_code: | ||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Resolve the source path (handles both local paths and HF model IDs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
resolved_source_path = _resolve_model_path(source_path, trust_remote_code) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
source_dir = Path(resolved_source_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||
export_dir = Path(export_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
if not source_dir.exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
if resolved_source_path != source_path: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||||||||||||||||||||||||||
f"Warning: Could not find local cache for HuggingFace model '{source_path}' " | ||||||||||||||||||||||||||||||||||||||||||||||||||
f"(resolved to '{resolved_source_path}')" | ||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Warning: Source directory '{source_path}' does not exist") | ||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
if not export_dir.exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Warning: Export directory {export_path} does not exist") | ||||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# Common patterns for custom model files that need to be copied | ||||||||||||||||||||||||||||||||||||||||||||||||||
custom_file_patterns = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||
Edwardf0t1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||
"configuration_*.py", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"modeling_*.py", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"tokenization_*.py", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"processing_*.py", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"image_processing_*.py", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"feature_extraction_*.py", | ||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
copied_files = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||
for pattern in custom_file_patterns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
for file_path in source_dir.glob(pattern): | ||||||||||||||||||||||||||||||||||||||||||||||||||
if file_path.is_file(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
dest_path = export_dir / file_path.name | ||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||
shutil.copy2(file_path, dest_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||
copied_files.append(file_path.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Copied custom model file: {file_path.name}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Warning: Failed to copy {file_path.name}: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
393
to
407
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Preserve directory structure and recurse; current code may miss files and break imports
Apply this diff to recurse and keep relative paths: - for pattern in custom_file_patterns:
- for file_path in source_dir.glob(pattern):
- if file_path.is_file():
- dest_path = export_dir / file_path.name
- try:
- shutil.copy2(file_path, dest_path)
- copied_files.append(file_path.name)
- print(f"Copied custom model file: {file_path.name}")
- except Exception as e:
- print(f"Warning: Failed to copy {file_path.name}: {e}")
+ for pattern in custom_file_patterns:
+ for file_path in source_dir.rglob(pattern):
+ if file_path.is_file():
+ rel_path = file_path.relative_to(source_dir)
+ dest_path = export_dir / rel_path
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
+ try:
+ shutil.copy2(file_path, dest_path)
+ copied_files.append(str(rel_path))
+ print(f"Copied custom model file: {rel_path}")
+ except Exception as e:
+ print(f"Warning: Failed to copy {rel_path}: {e}") 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||
if copied_files: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
print("No custom model files found to copy") |
Uh oh!
There was an error while loading. Please reload this page.