Fine-tune a pre-trained Vision Transformer (ViT) on custom image classification tasks using HuggingFace Transformers, and run batch inference on new images.
- Overview
- Project Structure
- Prerequisites
- Setup
- Preparing Your Data
- Pipeline
- Running on a Supercomputer (SLURM)
- Configuration Reference
- Outputs
## Overview

This project provides a three-step pipeline:
- Fine-tuning — Retrain the classification head (and optionally the backbone) of `google/vit-base-patch16-224-in21k` on your own labeled images.
- Prediction — Run the fine-tuned model on a folder of unlabeled images and save predictions to a CSV file.
- Confusion matrix — Evaluate the fine-tuned model on the labeled training data and produce accuracy, F1 score, and a confusion matrix plot.
The base model (google/vit-base-patch16-224-in21k) is pre-trained on ImageNet-21k (14 million images, 21k classes). Fine-tuning replaces the classification head with one matching your number of classes, while keeping the learned image representations from pre-training.
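To make the head replacement concrete, here is a small sketch of the label bookkeeping involved. Fine-tuning scripts typically sort the class folder names and store the resulting `id2label`/`label2id` maps in the checkpoint's `config.json`; the class names below are illustrative, not part of this repository.

```python
def build_label_maps(class_names):
    """Build the id2label/label2id maps that a fine-tuned checkpoint
    typically stores in config.json."""
    names = sorted(class_names)  # deterministic class ordering
    id2label = {i: name for i, name in enumerate(names)}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id

id2label, label2id = build_label_maps(["cat", "dog", "bird"])
print(id2label)  # {0: 'bird', 1: 'cat', 2: 'dog'}
```

The new classification head then has one output unit per entry in `id2label`, while the pre-trained backbone weights are kept.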
## Project Structure

```
vit/
├── data/
│   ├── images_train/<task>/     # Labeled training images (subfolders per class)
│   └── images_inf/<task>/       # Unlabeled images for inference
├── model/
│   └── google-vit-base-patch16-224-in21k/  # Pre-trained base model (auto-downloaded)
├── output/
│   ├── finetuned_model/<task>/  # Fine-tuned model weights and config
│   ├── prediction/<task>/       # Prediction CSV
│   └── confusion_matrix/<task>/ # Confusion matrix plot
├── script/
│   ├── fine_tuning/
│   │   ├── fine_tuning.py       # Fine-tuning script
│   │   ├── fine_tuning.sh       # SLURM submission script
│   │   └── slurm/               # SLURM log output
│   └── prediction/
│       ├── 00_prediction.py     # Batch inference script
│       ├── 00_prediction.sh     # SLURM submission script
│       ├── 01_confusion_matrix.py  # Evaluation script
│       ├── 01_confusion_matrix.sh  # SLURM submission script
│       └── slurm/               # SLURM log output
├── pyproject.toml               # Python dependencies (managed by uv)
└── .python-version              # Python version (3.11.10)
```
## Prerequisites

- Python 3.11.10 (exact version required by `pyproject.toml`)
- uv — Python package and project manager
- GPU (optional) — scripts automatically detect CUDA and fall back to CPU
## Setup

1. Install uv if you don't have it:

   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```

2. Clone the repository:

   ```bash
   git clone <repository-url>
   cd vit
   ```

3. Create the virtual environment and install dependencies:

   ```bash
   uv sync
   ```

   This reads `pyproject.toml`, creates a `.venv/` directory with Python 3.11.10, and installs all dependencies. If Python 3.11.10 is not installed on your machine, uv will download it automatically.

4. Verify the installation:

   ```bash
   uv run python -c "import transformers; print(transformers.__version__)"
   ```
The base model (`google/vit-base-patch16-224-in21k`) is downloaded automatically on the first run of `fine_tuning.py` and saved to `model/google-vit-base-patch16-224-in21k/`. No manual download is needed.
## Preparing Your Data

Organize your labeled images into subfolders inside `data/images_train/<task>/`, where each subfolder name is a class label:
```
data/images_train/pet_type/
├── cat/
│   ├── image001.jpg
│   ├── image002.jpg
│   └── ...
├── dog/
│   ├── image001.jpg
│   └── ...
└── bird/
    ├── image001.jpg
    └── ...
```
- Each subfolder name becomes a class label (e.g., `cat`, `dog`, `bird`).
- Supported image formats: `.jpg`, `.jpeg`, `.png`, `.bmp`, `.tiff`, `.webp`.
- Every class folder must contain at least one image.
- There is no minimum dataset size, but a few hundred images per class is recommended.
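Before launching a long training run, it can be worth checking the folder layout against the rules above. A minimal stdlib sketch (not part of the repository's scripts) could look like this:

```python
from pathlib import Path

# Extensions accepted by the pipeline, per the list above.
SUPPORTED = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

def validate_train_dir(task_dir):
    """Return {class_name: image_count} for a data/images_train/<task>/ folder.

    Raises ValueError if a class folder contains no supported images.
    """
    counts = {}
    for class_dir in sorted(p for p in Path(task_dir).iterdir() if p.is_dir()):
        n = sum(1 for f in class_dir.iterdir() if f.suffix.lower() in SUPPORTED)
        if n == 0:
            raise ValueError(f"class folder {class_dir.name!r} has no images")
        counts[class_dir.name] = n
    return counts
```

Running it before training catches an empty class folder early instead of failing partway into the job.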
Place the unlabeled images you want to classify into `data/images_inf/<task>/` as a flat folder (no subfolders needed):
```
data/images_inf/pet_type/
├── photo_a.jpg
├── photo_b.png
└── ...
```
All scripts use relative paths resolved from the repository root, so they work identically on a local machine and on a supercomputer — no path editing is needed.
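One common way to implement this (a sketch under the layout shown above, not necessarily how the repository's scripts do it) is to walk up from the script's own location, since every script sits two levels below the repository root:

```python
from pathlib import Path

def repo_relative(script_path, *parts):
    """Resolve a path against the repository root.

    Assumes the caller sits two levels below the root, e.g.
    script/fine_tuning/fine_tuning.py -> <root>.
    """
    root = Path(script_path).resolve().parents[2]
    return root.joinpath(*parts)

# Usage from within script/fine_tuning/fine_tuning.py:
# train_dir = repo_relative(__file__, "data", "images_train", "pet_type")
```

Because the root is derived from the script file rather than the current working directory, the same code works locally and on a cluster.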
## Pipeline

### Step 1: Fine-tuning

Script: `script/fine_tuning/fine_tuning.py`
This script:
- Loads images from `data/images_train/<task>/` using HuggingFace's ImageFolder loader.
- Splits the data into training and validation sets (85/15 by default).
- Applies data augmentation (random crop + horizontal flip for training, center crop for validation).
- Fine-tunes the ViT model for the configured number of epochs.
- Saves the best model (by validation accuracy) to `output/finetuned_model/<task>/`.
Run locally:

```bash
uv run script/fine_tuning/fine_tuning.py
```

What to expect: on CPU, a small dataset (~700 images, 4 epochs) takes roughly 40 minutes. On a GPU, the same run completes in a few minutes.
Before running, open `script/fine_tuning/fine_tuning.py` and set:

- `TASK_NAME` — your task identifier (e.g., `"pet_type"`). Must match the folder name under `data/images_train/`.
Optional adjustments (all at the top of the script):
| Parameter | Default | Description |
|---|---|---|
| `BATCH_SIZE` | 16 | Images per GPU per training step |
| `NUM_EPOCHS` | 4 | Full passes over the training set |
| `LEARNING_RATE` | 2e-4 | Step size for the AdamW optimizer |
| `TEST_SIZE` | 0.15 | Fraction of data reserved for validation |
| `SAVE_STEPS` | 100 | Save a checkpoint every N training steps |
| `EVAL_STEPS` | 100 | Run validation every N training steps |
| `LOGGING_STEPS` | 10 | Log training loss every N steps |
| `SAVE_TOTAL_LIMIT` | 2 | Keep only the N most recent checkpoints |
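To see how these defaults interact, here is a back-of-the-envelope calculation for a hypothetical 700-image dataset (the size used in the timing estimate above), assuming a single device; with multiple GPUs the effective batch grows and the step count shrinks:

```python
import math

# Defaults from the table above, hypothetical 700-image dataset.
num_images = 700
test_size  = 0.15   # TEST_SIZE
batch_size = 16     # BATCH_SIZE
num_epochs = 4      # NUM_EPOCHS
save_steps = 100    # SAVE_STEPS

train_images    = round(num_images * (1 - test_size))   # 595 images for training
steps_per_epoch = math.ceil(train_images / batch_size)  # 38 optimizer steps per epoch
total_steps     = steps_per_epoch * num_epochs          # 152 steps in total
checkpoints     = total_steps // save_steps             # 1 periodic checkpoint (plus the final save)

print(train_images, steps_per_epoch, total_steps, checkpoints)  # 595 38 152 1
```

So with these defaults a small run triggers only one periodic checkpoint and one mid-run evaluation; for very small datasets you may want to lower `SAVE_STEPS` and `EVAL_STEPS` so validation happens more than once.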
### Step 2: Prediction

Script: `script/prediction/00_prediction.py`
This script:
- Loads the fine-tuned model from `output/finetuned_model/<task>/`.
- Scans all images in `data/images_inf/<task>/`.
- Runs inference in batches and computes class probabilities.
- Saves a CSV file to `output/prediction/<task>/predictions.csv`.
Run locally:

```bash
uv run script/prediction/00_prediction.py
```

Before running, make sure:

- You have completed Step 1 (the fine-tuned model must exist).
- You have placed images in `data/images_inf/<task>/`.
- `TASK_NAME` in the script matches your task.
Output CSV format:
| filename | predicted_label | prob_cat | prob_dog | prob_bird |
|---|---|---|---|---|
| photo_a.jpg | cat | 0.9213 | 0.0512 | 0.0275 |
| photo_b.png | bird | 0.0134 | 0.0891 | 0.8975 |
The probability columns are named dynamically based on the class labels in your dataset.
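A sketch of how such a CSV can be produced from raw model logits (the softmax helper and the logit values here are illustrative; the repository's script may differ in details):

```python
import csv
import io
import math

def softmax(logits):
    """Convert raw logits to probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def write_predictions(out, rows, class_names):
    """rows: iterable of (filename, logits) pairs.

    Column layout matches the table above: one prob_<class> column
    per class label, named dynamically.
    """
    writer = csv.writer(out)
    writer.writerow(["filename", "predicted_label"] +
                    [f"prob_{c}" for c in class_names])
    for filename, logits in rows:
        probs = softmax(logits)
        label = class_names[probs.index(max(probs))]
        writer.writerow([filename, label] + [f"{p:.4f}" for p in probs])

buf = io.StringIO()
write_predictions(buf, [("photo_a.jpg", [2.0, -1.0, -0.5])], ["cat", "dog", "bird"])
```

Because the header is built from the class list, retraining on a task with different labels changes the columns automatically.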
### Step 3: Confusion matrix

Script: `script/prediction/01_confusion_matrix.py`
This script:
- Loads the fine-tuned model from `output/finetuned_model/<task>/`.
- Loads the labeled training images from `data/images_train/<task>/`.
- Runs inference on all labeled images and compares predictions to true labels.
- Prints accuracy and weighted F1 score to the terminal.
- Saves a confusion matrix plot to `output/confusion_matrix/<task>/confusion_matrix.png`.

Note that this evaluates on the same data the model was trained on, so the reported scores are optimistic; treat them as a sanity check rather than an estimate of real-world accuracy.
Run locally:

```bash
uv run script/prediction/01_confusion_matrix.py
```

Before running, make sure:

- You have completed Step 1 (the fine-tuned model must exist).
- `TASK_NAME` in the script matches your task.
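For reference, the underlying metrics are simple to state. A dependency-free sketch (class names illustrative, not from the repository) with rows as true classes and columns as predicted classes, the usual layout for such plots:

```python
def confusion_matrix(y_true, y_pred, labels):
    """Count matrix: matrix[i][j] = samples of true class i predicted as class j."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

labels = ["cat", "dog"]
y_true = ["cat", "cat", "dog", "dog"]
y_pred = ["cat", "dog", "dog", "dog"]
print(confusion_matrix(y_true, y_pred, labels))  # [[1, 1], [0, 2]]
print(accuracy(y_true, y_pred))                  # 0.75
```

Off-diagonal entries show which classes the model confuses, which is often more actionable than the single accuracy number.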
## Running on a Supercomputer (SLURM)

Each Python script has a matching `.sh` file configured for HPC usage. The shell scripts handle module loading, environment activation, and HuggingFace cache configuration.
1. Transfer the repository to your project directory (e.g., `/project/home/p200804/vit/`).

2. Create the virtual environment:

   ```bash
   module load env/release/2024.1
   module load Python/3.11.10-GCCcore-13.3.0
   uv sync
   ```

3. Create SLURM output directories:

   ```bash
   mkdir -p script/fine_tuning/slurm
   mkdir -p script/prediction/slurm
   ```

4. Place your training images in `data/images_train/<task>/` following the folder structure described above.
All sbatch commands must be run from the repository root:

```bash
cd /project/home/p200804/vit/

# Step 1: Fine-tune
sbatch script/fine_tuning/fine_tuning.sh

# Step 2: Predict (after fine-tuning completes)
sbatch script/prediction/00_prediction.sh

# Step 3: Confusion matrix (after fine-tuning completes)
sbatch script/prediction/01_confusion_matrix.sh

# Check job status
squeue -u $USER

# View live output
tail -f script/fine_tuning/slurm/fine_tuning.out
tail -f script/prediction/slurm/00_prediction.out
tail -f script/prediction/slurm/01_confusion_matrix.out
```

Resource requests in the submission scripts:

| Script | GPUs | CPUs | Time limit | Partition |
|---|---|---|---|---|
| `fine_tuning.sh` | 4 | 32 | X hours | gpu |
| `00_prediction.sh` | 1 | 32 | X hours | gpu |
| `01_confusion_matrix.sh` | 1 | 32 | X hours | gpu |
Adjust the `--time` parameter in the `.sh` files if your dataset is significantly larger or smaller.
## Configuration Reference

To adapt this project to a new classification task:
1. Pick a task name (e.g., `"age"`, `"breed"`, `"defect_type"`).

2. Create the data folders:

   ```bash
   mkdir -p data/images_train/<task_name>/<class_1>
   mkdir -p data/images_train/<task_name>/<class_2>
   mkdir -p data/images_inf/<task_name>
   ```

3. Place your images in the corresponding folders.

4. Update `TASK_NAME` in all three Python scripts:
   - `script/fine_tuning/fine_tuning.py`
   - `script/prediction/00_prediction.py`
   - `script/prediction/01_confusion_matrix.py`

5. Run the pipeline (Steps 1-3 above).
No other code changes are needed — the scripts dynamically detect class labels from subfolder names and configure the model accordingly.
## Outputs

After running the full pipeline, your `output/` directory will contain:
```
output/
├── finetuned_model/<task>/          # Fine-tuned model
│   ├── config.json                  # Model architecture and label mapping
│   ├── model.safetensors            # Model weights
│   ├── preprocessor_config.json     # Image processor settings
│   ├── training_args.bin            # Training configuration
│   ├── trainer_state.json           # Training state and history
│   ├── train_results.json           # Training metrics
│   ├── eval_results.json            # Evaluation metrics
│   └── runs/                        # TensorBoard logs
├── prediction/<task>/
│   └── predictions.csv              # Inference results with probabilities
└── confusion_matrix/<task>/
    └── confusion_matrix.png         # Evaluation plot
```
To visualize training metrics with TensorBoard:

```bash
uv run tensorboard --logdir output/finetuned_model/<task>/runs
```