@@ -74,6 +74,10 @@ if [ -n "$KV_CACHE_QUANT" ]; then
7474fi
7575
7676if [[ "${MODEL_NAME,,}" == *"vila"* ]]; then
77+    # Save current transformers version for later restoration
78+    ORIGINAL_TRANSFORMERS_VERSION=$(pip show transformers | grep Version | cut -d' ' -f2)
79+    echo "Current transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
80+
7781 # Install required dependency for VILA
7882 pip install -r ../vlm_ptq/requirements-vila.txt
7983 # Clone original VILA repo
@@ -103,6 +107,29 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
103107 fi
104108fi
105109
110+# Fix model_type for VILA models (they are based on LLaVA architecture)
111+# VILA models need to be recognized as llava_llama for TensorRT-LLM multimodal support
112+if [[ "${MODEL_NAME,,}" == *"vila"* ]] && [ -f "$MODEL_CONFIG" ]; then
113+    echo "Updating model_type in config for VILA model..."
114+    python3 -c "
115+import json
116+with open('$MODEL_CONFIG', 'r') as f:
117+    config = json.load(f)
118+if config.get('model_type') == 'llama':
119+    config['model_type'] = 'llava_llama'
120+with open('$MODEL_CONFIG', 'w') as f:
121+    json.dump(config, f, indent=4)
122+print('Updated model_type from llama to llava_llama in $MODEL_CONFIG')
123+"
124+fi
125+
126+# Restore original transformers version immediately after PTQ for VILA models
127+if [[ "${MODEL_NAME,,}" == *"vila"* ]] && [ -n "$ORIGINAL_TRANSFORMERS_VERSION" ]; then
128+    echo "Restoring original transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
129+    pip install transformers==$ORIGINAL_TRANSFORMERS_VERSION
130+    echo "Transformers version restored successfully."
131+fi
132+
106133if [[ "$QFORMAT" != "fp8" ]]; then
107134    echo "For quant format $QFORMAT, please refer to the TensorRT-LLM documentation for deployment. Checkpoint saved to $SAVE_PATH."
108135    exit 0
0 commit comments