Skip to content

Commit 916016c

Browse files
committed
Add transformers restoration after PTQ for VILA
Signed-off-by: Yue <[email protected]>
1 parent 5adb9ba commit 916016c

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

examples/vlm_ptq/scripts/huggingface_example.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ if [ -n "$KV_CACHE_QUANT" ]; then
7474
fi
7575

7676
if [[ "${MODEL_NAME,,}" == *"vila"* ]]; then
77+
# Save the currently installed transformers version so it can be restored
# after the VILA-specific requirements replace it.
# Anchor the match to the start of the line so only pip's "Version:" field
# is captured (grep Version | cut was fragile and spawned two extra procs).
ORIGINAL_TRANSFORMERS_VERSION=$(pip show transformers | awk '/^Version:/ {print $2; exit}')
echo "Current transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
80+
7781
# Install required dependency for VILA
7882
pip install -r ../vlm_ptq/requirements-vila.txt
7983
# Clone original VILA repo
@@ -103,6 +107,29 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
103107
fi
104108
fi
105109

110+
# Fix model_type for VILA models (they are based on LLaVA architecture).
# VILA models need to be recognized as llava_llama for TensorRT-LLM
# multimodal support.
#
# update_vila_model_type PATH
#   Rewrites the JSON config at PATH in place, changing model_type from
#   "llama" to "llava_llama". Only writes (and reports) when a change is
#   actually made; configs with any other model_type are left untouched.
update_vila_model_type() {
  local config_path=$1
  # Pass the path as argv instead of interpolating it into the Python
  # source, so quotes/special characters in the path cannot break or
  # inject into the script.
  python3 - "$config_path" <<'PY'
import json
import sys

path = sys.argv[1]
with open(path, 'r') as f:
    config = json.load(f)
if config.get('model_type') == 'llama':
    config['model_type'] = 'llava_llama'
    with open(path, 'w') as f:
        json.dump(config, f, indent=4)
    print('Updated model_type from llama to llava_llama in ' + path)
PY
}

if [[ "${MODEL_NAME,,}" == *"vila"* && -f "$MODEL_CONFIG" ]]; then
  echo "Updating model_type in config for VILA model..."
  update_vila_model_type "$MODEL_CONFIG"
fi
125+
126+
# Restore the original transformers version immediately after PTQ for VILA
# models (the VILA requirements install a different transformers that the
# rest of the pipeline must not keep using).
# ${VAR:-} keeps the test safe under `set -u` when the save step never ran.
if [[ "${MODEL_NAME,,}" == *"vila"* && -n "${ORIGINAL_TRANSFORMERS_VERSION:-}" ]]; then
  echo "Restoring original transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
  # Quote the requirement spec and only report success if pip actually
  # succeeded; previously "restored successfully" was printed even when
  # the install failed.
  if pip install "transformers==$ORIGINAL_TRANSFORMERS_VERSION"; then
    echo "Transformers version restored successfully."
  else
    echo "Warning: failed to restore transformers==$ORIGINAL_TRANSFORMERS_VERSION" >&2
  fi
fi
132+
106133
if [[ "$QFORMAT" != "fp8" ]]; then
107134
echo "For quant format $QFORMAT, please refer to the TensorRT-LLM documentation for deployment. Checkpoint saved to $SAVE_PATH."
108135
exit 0

0 commit comments

Comments
 (0)