Skip to content

Commit 4aa4285

Browse files
committed
add a util function to resolve model path
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 3b0f98a commit 4aa4285

File tree

1 file changed

+88
-3
lines changed

1 file changed

+88
-3
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import glob
1617
import os
1718
import shutil
1819
import sys
@@ -26,6 +27,11 @@
2627
from accelerate.utils import get_max_memory
2728
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
2829

30+
try:
31+
from huggingface_hub import snapshot_download
32+
except ImportError:
33+
snapshot_download = None
34+
2935
from modelopt.torch.utils.image_processor import MllamaImageProcessor
3036

3137
SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
@@ -267,22 +273,101 @@ def apply_kv_cache_quant(quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str
267273
return quant_cfg
268274

269275

276+
def _find_cached_snapshot(cache_root: str, model_id: str) -> "str | None":
    """Locate the snapshot directory of ``model_id`` inside a HuggingFace hub cache.

    Args:
        cache_root: Root of the hub cache (the directory holding ``models--*`` folders).
        model_id: HuggingFace model ID such as ``"org/name"``.

    Returns:
        Path to a snapshot directory, or ``None`` if the model is not cached.
    """
    # Match the repo folder exactly. A substring test would also accept the
    # cache of a different model such as "org/name-v2" or "other-org/name".
    repo_dir = os.path.join(cache_root, "models--" + model_id.replace("/", "--"))
    snapshots_dir = os.path.join(repo_dir, "snapshots")
    if not os.path.isdir(snapshots_dir):
        return None

    # Snapshot folders are named after commit hashes, so their lexicographic
    # order carries no meaning. Prefer the commit recorded for the "main" ref.
    main_ref = os.path.join(repo_dir, "refs", "main")
    if os.path.isfile(main_ref):
        with open(main_ref, encoding="utf-8") as ref_file:
            commit_hash = ref_file.read().strip()
        if commit_hash:
            ref_snapshot = os.path.join(snapshots_dir, commit_hash)
            if os.path.isdir(ref_snapshot):
                return ref_snapshot

    # No usable ref: fall back to the most recently modified snapshot.
    candidates = [
        path
        for path in (os.path.join(snapshots_dir, entry) for entry in os.listdir(snapshots_dir))
        if os.path.isdir(path)
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def _resolve_model_path(model_name_or_path: str, trust_remote_code: bool = False) -> str:
    """Resolve a model name or path to a local directory path.

    If the input is already a local directory, returns it as-is.
    If the input is a HuggingFace model ID, attempts to resolve it to the local cache path.

    Resolution is best-effort: any failure is reported as a printed warning and
    the original input is returned unchanged, so this function does not raise
    for an unresolvable model ID.

    Args:
        model_name_or_path: Either a local directory path or HuggingFace model ID
        trust_remote_code: Whether to trust remote code when loading the model

    Returns:
        Local directory path to the model files, or ``model_name_or_path``
        unchanged when no local directory could be determined.
    """
    # If it's already a local directory, return as-is
    if os.path.isdir(model_name_or_path):
        return model_name_or_path

    # Try to resolve HuggingFace model ID to local cache path
    try:
        # Loading the config first triggers download/caching of the repo's config files
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)

        # Guard against a missing/None/empty value so that it cannot abort the
        # remaining fallbacks with a TypeError from os.path.isdir.
        config_path = getattr(config, "_name_or_path", None)
        if config_path and os.path.isdir(config_path):
            return config_path

        # Alternative: use snapshot_download if available
        if snapshot_download is not None:
            try:
                return snapshot_download(
                    repo_id=model_name_or_path,
                    allow_patterns=["*.py", "*.json"],  # Only download Python files and config
                )
            except Exception as e:
                print(f"Warning: Could not download model files using snapshot_download: {e}")

        # Fallback: try to find in HuggingFace cache
        from transformers.utils import TRANSFORMERS_CACHE

        snapshot_path = _find_cached_snapshot(TRANSFORMERS_CACHE, model_name_or_path)
        if snapshot_path is not None:
            return snapshot_path

    except Exception as e:
        # Deliberately broad: resolution is optional and must never break the caller.
        print(f"Warning: Could not resolve model path for {model_name_or_path}: {e}")

    # If all else fails, return the original path
    # This will cause the copy function to skip with a warning
    return model_name_or_path
345+
346+
270347
def copy_custom_model_files(source_path: str, export_path: str, trust_remote_code: bool = False):
271348
"""Copy custom model files (configuration_*.py, modeling_*.py, etc.) from source to export directory.
272349
273350
Args:
274-
source_path: Path to the original model directory
351+
source_path: Path to the original model directory or HuggingFace model ID
275352
export_path: Path to the exported model directory
276353
trust_remote_code: Whether trust_remote_code was used (only copy files if True)
277354
"""
278355
if not trust_remote_code:
279356
return
280357

281-
source_dir = Path(source_path)
358+
# Resolve the source path (handles both local paths and HF model IDs)
359+
resolved_source_path = _resolve_model_path(source_path, trust_remote_code)
360+
361+
source_dir = Path(resolved_source_path)
282362
export_dir = Path(export_path)
283363

284364
if not source_dir.exists():
285-
print(f"Warning: Source directory {source_path} does not exist")
365+
if resolved_source_path != source_path:
366+
print(
367+
f"Warning: Could not find local cache for HuggingFace model '{source_path}' (resolved to '{resolved_source_path}')"
368+
)
369+
else:
370+
print(f"Warning: Source directory '{source_path}' does not exist")
286371
return
287372

288373
if not export_dir.exists():

0 commit comments

Comments
 (0)