diff --git a/examples/vlm_ptq/scripts/huggingface_example.sh b/examples/vlm_ptq/scripts/huggingface_example.sh
index cbaf7c125..ce5f2305e 100755
--- a/examples/vlm_ptq/scripts/huggingface_example.sh
+++ b/examples/vlm_ptq/scripts/huggingface_example.sh
@@ -74,6 +74,10 @@ if [ -n "$KV_CACHE_QUANT" ]; then
 fi

 if [[ "${MODEL_NAME,,}" == *"vila"* ]]; then
+    # Save current transformers version for later restoration
+    ORIGINAL_TRANSFORMERS_VERSION=$(pip show transformers | grep Version | cut -d' ' -f2)
+    echo "Current transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
+
     # Install required dependency for VILA
     pip install -r ../vlm_ptq/requirements-vila.txt
     # Clone original VILA repo
@@ -103,6 +107,29 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
     fi
 fi

+# Fix model_type for VILA models (they are based on the LLaVA architecture).
+# VILA models need to be recognized as llava_llama for TensorRT-LLM multimodal support.
+if [[ "${MODEL_NAME,,}" == *"vila"* ]] && [ -f "$MODEL_CONFIG" ]; then
+    echo "Updating model_type in config for VILA model..."
+    python3 -c "
+import json
+with open('$MODEL_CONFIG', 'r') as f:
+    config = json.load(f)
+if config.get('model_type') == 'llama':
+    config['model_type'] = 'llava_llama'
+    with open('$MODEL_CONFIG', 'w') as f:
+        json.dump(config, f, indent=4)
+    print('Updated model_type from llama to llava_llama in $MODEL_CONFIG')
+"
+fi
+
+# Restore the original transformers version immediately after PTQ for VILA models
+if [[ "${MODEL_NAME,,}" == *"vila"* ]] && [ -n "$ORIGINAL_TRANSFORMERS_VERSION" ]; then
+    echo "Restoring original transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
+    pip install transformers==$ORIGINAL_TRANSFORMERS_VERSION
+    echo "Transformers version restored successfully."
+fi
+
 if [[ "$QFORMAT" != "fp8" ]]; then
     echo "For quant format $QFORMAT, please refer to the TensorRT-LLM documentation for deployment. Checkpoint saved to $SAVE_PATH."
     exit 0
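
Note on the version capture in the first hunk: "grep Version" matches any pip show output line containing that string, and cut -d' ' -f2 assumes exactly one space after the colon. A minimal sketch of a more defensive variant (hypothetical, not part of this patch) that either reads the version from the package itself or anchors the grep to the Version: field:

# Sketch: query the installed package directly; empty if transformers is missing.
ORIGINAL_TRANSFORMERS_VERSION=$(python3 -c "import transformers; print(transformers.__version__)" 2>/dev/null)

# Or keep pip but anchor the match to the start of the Version: line.
ORIGINAL_TRANSFORMERS_VERSION=$(pip show transformers 2>/dev/null | grep '^Version:' | cut -d' ' -f2)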
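
Note on the model_type fix: the inline python3 -c program is double-quoted, so bash expands $MODEL_CONFIG directly into the Python source; a path containing spaces or quotes would break the program. A hypothetical equivalent with the same behavior (and the same $MODEL_CONFIG variable) that passes the path via argv instead:

# Sketch: read the script from stdin and take the config path as an argument,
# so no shell interpolation happens inside the Python source.
python3 - "$MODEL_CONFIG" <<'EOF'
import json
import sys

config_path = sys.argv[1]
with open(config_path, 'r') as f:
    config = json.load(f)
# VILA checkpoints report model_type "llama"; TensorRT-LLM's multimodal
# runtime expects "llava_llama", so rewrite it in place.
if config.get('model_type') == 'llama':
    config['model_type'] = 'llava_llama'
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=4)
    print(f'Updated model_type from llama to llava_llama in {config_path}')
EOF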
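
Note on the restore step: the success message prints even when pip install fails. A hypothetical guard on pip's exit status, using the same variable as the patch:

# Sketch: only report success when pip actually succeeded.
if pip install "transformers==$ORIGINAL_TRANSFORMERS_VERSION"; then
    echo "Transformers version restored successfully."
else
    echo "Warning: failed to restore transformers==$ORIGINAL_TRANSFORMERS_VERSION" >&2
fi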