27 changes: 27 additions & 0 deletions examples/vlm_ptq/scripts/huggingface_example.sh
@@ -74,6 +74,10 @@ if [ -n "$KV_CACHE_QUANT" ]; then
fi

if [[ "${MODEL_NAME,,}" == *"vila"* ]]; then
# Save current transformers version for later restoration
ORIGINAL_TRANSFORMERS_VERSION=$(pip show transformers | grep Version | cut -d' ' -f2)
echo "Current transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
Collaborator:

How about we just do a check and prompt the user to install the supported version (if not, exit)?

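A minimal sketch of the check-and-exit approach suggested above, assuming a hypothetical REQUIRED_TRANSFORMERS_VERSION pin (the exact version VILA requires is not stated in this diff):

REQUIRED_TRANSFORMERS_VERSION="4.36.2"  # hypothetical pin; replace with the version VILA actually supports
INSTALLED_TRANSFORMERS_VERSION=$(python3 -c "import transformers; print(transformers.__version__)" 2>/dev/null)
if [ "$INSTALLED_TRANSFORMERS_VERSION" != "$REQUIRED_TRANSFORMERS_VERSION" ]; then
    echo "VILA requires transformers==$REQUIRED_TRANSFORMERS_VERSION, found ${INSTALLED_TRANSFORMERS_VERSION:-none}."
    echo "Please run: pip install transformers==$REQUIRED_TRANSFORMERS_VERSION"
    exit 1
fi
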
# Install required dependency for VILA
pip install -r ../vlm_ptq/requirements-vila.txt
# Clone original VILA repo
@@ -103,6 +107,29 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
fi
fi

# Fix model_type for VILA models (they are based on LLaVA architecture)
# VILA models need to be recognized as llava_llama for TensorRT-LLM multimodal support
if [[ "${MODEL_NAME,,}" == *"vila"* ]] && [ -f "$MODEL_CONFIG" ]; then
echo "Updating model_type in config for VILA model..."
python3 -c "
import json
with open('$MODEL_CONFIG', 'r') as f:
    config = json.load(f)
if config.get('model_type') == 'llama':
    config['model_type'] = 'llava_llama'
    # Only rewrite the config when the model_type actually changed
    with open('$MODEL_CONFIG', 'w') as f:
        json.dump(config, f, indent=4)
    print('Updated model_type from llama to llava_llama in $MODEL_CONFIG')
"
fi

# Restore original transformers version immediately after PTQ for VILA models
if [[ "${MODEL_NAME,,}" == *"vila"* ]] && [ -n "$ORIGINAL_TRANSFORMERS_VERSION" ]; then
echo "Restoring original transformers version: $ORIGINAL_TRANSFORMERS_VERSION"
pip install transformers==$ORIGINAL_TRANSFORMERS_VERSION
echo "Transformers version restored successfully."
fi
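A quick sanity check after the reinstall could confirm the restore actually took effect; this is only a sketch, not part of the original change:

RESTORED_TRANSFORMERS_VERSION=$(python3 -c "import transformers; print(transformers.__version__)")
if [ "$RESTORED_TRANSFORMERS_VERSION" != "$ORIGINAL_TRANSFORMERS_VERSION" ]; then
    echo "Warning: expected transformers==$ORIGINAL_TRANSFORMERS_VERSION but found $RESTORED_TRANSFORMERS_VERSION."
fi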

if [[ "$QFORMAT" != "fp8" ]]; then
echo "For quant format $QFORMAT, please refer to the TensorRT-LLM documentation for deployment. Checkpoint saved to $SAVE_PATH."
exit 0