
Commit 99fcd06

yueshen2016 and mxinO authored and committed
[BUG FIX 5616904]: Make VILA codebase importable and import configuration before loading model config (#511)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Make the VILA codebase importable and import its configuration before loading the model config.

Fixes https://nvbugspro.nvidia.com/bug/5616904

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

```
CUDA_VISIBLE_DEVICES=0 bash -e scripts/huggingface_example.sh --model /models/vila1.5-3b --quant int4_awq --tp 1 --pp 1 --trust_remote_code --kv_cache_free_gpu_memory_fraction 0.5
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

Signed-off-by: Yue <[email protected]>
Signed-off-by: mxin <[email protected]>
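Concretely, the pattern this fix relies on looks roughly like the sketch below: put the VILA sources on `sys.path` and import `llava.model` before the checkpoint config is read. The checkpoint path mirrors the test command above and the directory layout mirrors the diff below; the registration side effect of the import is an assumption about VILA's packaging, not something stated in this commit.

```python
import os
import sys

from transformers import AutoConfig

# Hypothetical checkpoint location, taken from the test command in this PR.
ckpt_path = "/models/vila1.5-3b"

# The VILA sources are assumed to live in a "VILA" directory next to the
# checkpoint, as in the diff below.
vila_path = os.path.join(ckpt_path, "..", "VILA")
if vila_path not in sys.path:
    sys.path.append(vila_path)

# Importing llava.model is assumed to register the custom LlavaLlama classes
# with transformers as a side effect, which is why it must happen *before*
# the config is loaded.
from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa: F401

config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=True)
```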
1 parent 805d2ea commit 99fcd06

File tree

1 file changed: +7 -2 lines changed


examples/llm_ptq/example_utils.py

Lines changed: 7 additions & 2 deletions
@@ -270,6 +270,13 @@ def get_model(
     if device == "cpu":
         device_map = "cpu"
 
+    # Add VILA to sys.path before loading config if needed
+    if "vila" in ckpt_path.lower():
+        vila_path = os.path.join(ckpt_path, "..", "VILA")
+        if vila_path not in sys.path:
+            sys.path.append(vila_path)
+        from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa: F401
+
     # Prepare config kwargs for loading
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
 
@@ -295,8 +302,6 @@ def get_model(
     model_kwargs.setdefault("torch_dtype", "auto")
 
     if "vila" in ckpt_path.lower():
-        sys.path.append(os.path.join(ckpt_path, "..", "VILA"))
-        from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa: F401
         from transformers import AutoModel
 
         hf_vila = AutoModel.from_pretrained(
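Why the import has to precede the config load: `llava.model` is presumed to register VILA's custom classes with the `transformers` Auto* factories at import time, roughly like the hedged reconstruction below. The class bodies and the `llava_llama` model type are illustrative assumptions, not VILA's actual source.

```python
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class LlavaLlamaConfig(PretrainedConfig):
    # Assumed to match the "model_type" field in a VILA checkpoint's config.json.
    model_type = "llava_llama"


class LlavaLlamaModel(PreTrainedModel):
    config_class = LlavaLlamaConfig


# Registration at import time is what later lets AutoConfig.from_pretrained and
# AutoModel.from_pretrained resolve the checkpoint's custom model type.
AutoConfig.register("llava_llama", LlavaLlamaConfig)
AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
```

If the import happens only after the config is requested, as in the removed lines above, `AutoConfig.from_pretrained` has no registration to fall back on and the unknown model type cannot be resolved; moving the import (and the `sys.path` update guarding against duplicates) ahead of the config load avoids that.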
