diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 86e847aca..3effb1d11 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -316,6 +316,7 @@ def main(args): mtq.quantize(child, disabled_quant_cfg, forward_loop=None) model = model.language_model + model_type = get_model_type(model) if args.sparsity_fmt != "dense": if args.batch_size == 0: diff --git a/examples/vlm_ptq/scripts/huggingface_example.sh b/examples/vlm_ptq/scripts/huggingface_example.sh index ea2733067..cbaf7c125 100755 --- a/examples/vlm_ptq/scripts/huggingface_example.sh +++ b/examples/vlm_ptq/scripts/huggingface_example.sh @@ -73,7 +73,7 @@ if [ -n "$KV_CACHE_QUANT" ]; then PTQ_ARGS+=" --kv_cache_qformat=$KV_CACHE_QUANT " fi -if [ "${MODEL_TYPE}" = "vila" ]; then +if [[ "${MODEL_NAME,,}" == *"vila"* ]]; then # Install required dependency for VILA pip install -r ../vlm_ptq/requirements-vila.txt # Clone original VILA repo