Commit 8e6dea3

committed: update

Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent c572513 · commit 8e6dea3

File tree: 1 file changed (+5, -10 lines)


examples/llm_ptq/example_utils.py

Lines changed: 5 additions & 10 deletions
@@ -201,18 +201,18 @@ def get_model(
     # Prepare config kwargs for loading
     config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
 
-    # Special handling for vision-language models that may have device mapping issues
+    # Load config once and handle VL model detection
     try:
-        hf_config_check = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
-        if _is_multimodal_config(hf_config_check):
+        hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
+        if _is_multimodal_config(hf_config):
             print(
                 "Detected vision-language model from config. "
                 "Disabling automatic device mapping to avoid device_map errors."
             )
             device_map = None
     except Exception as e:
-        print(f"Warning: Could not load config for VL detection: {e}")
-        print("Model loading will likely fail. Please check the model path and configuration.")
+        print(f"Error: Could not load config from {ckpt_path}: {e}")
+        raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e
     if attn_implementation is not None:
         config_kwargs["attn_implementation"] = attn_implementation
 
@@ -234,11 +234,6 @@ def get_model(
         )
         model = hf_vila.llm
     else:
-        hf_config = AutoConfig.from_pretrained(
-            ckpt_path,
-            **config_kwargs,
-        )
-
         if use_seq_device_map:
             device_map = "sequential"
             # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU
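For readers skimming the diff: the first hunk consolidates config loading, so the config is fetched once and a load failure now aborts immediately instead of warning and continuing. Below is a minimal standalone sketch of that control flow; load_config_and_device_map and the _is_multimodal_config body are hypothetical stand-ins for illustration (only AutoConfig.from_pretrained and the raise ... from e pattern appear in the diff itself).

from transformers import AutoConfig

def _is_multimodal_config(config) -> bool:
    # Hypothetical stand-in for the repo's helper: many Hugging Face
    # vision-language configs expose a `vision_config` sub-config.
    return hasattr(config, "vision_config")

def load_config_and_device_map(ckpt_path, device_map, trust_remote_code=False):
    # Load the config exactly once, as the first hunk does.
    config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}
    try:
        hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)
    except Exception as e:
        # Fail fast; chaining with `from e` preserves the original traceback.
        raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e
    if _is_multimodal_config(hf_config):
        # VL models can hit device_map errors under automatic placement.
        device_map = None
    return hf_config, device_map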
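The second hunk deletes a now-redundant AutoConfig.from_pretrained call in the else branch (hf_config is already available from the load above) and leaves the sequential device-map path untouched. Since the max_memory code itself sits outside this diff, the sketch below is only a hedged guess at what such a cap typically looks like; the helper name and the 0.8 fraction are assumptions, not the repo's actual values.

import torch

def sequential_device_map_kwargs(mem_fraction: float = 0.8) -> dict:
    # Hypothetical helper: cap per-GPU memory so device_map="sequential"
    # does not let the model occupy the full first GPU before spilling over.
    max_memory = {
        i: int(torch.cuda.get_device_properties(i).total_memory * mem_fraction)
        for i in range(torch.cuda.device_count())
    }
    return {"device_map": "sequential", "max_memory": max_memory}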
