This repository presents the PyTorch code for PiPViT.
The main paper, "PiPViT: Patch-based Visual Interpretable Prototypes for Retinal Image Analysis" (available on arXiv), introduces PiPViT for the classification of retinal OCT B-scans.
The paper is published in Biomedical Signal Processing and Control.
PiPViT is an interpretable prototype-based model that learns to classify retinal OCT B-scans using only image-level labels. Built on a transformer backbone, PiPViT captures long-range dependencies between patches to learn robust, human-interpretable prototypes that approximate lesion extent, helping clinicians better understand diagnostic outcomes. Additionally, by leveraging contrastive learning and the flexibility of vision transformers to process images at different resolutions, PiPViT generates activation maps that effectively localize biomarkers of any size.
PiPViT overview: The ViT encoder extracts patch representations from the input image at a given resolution. Pretraining is conducted at three different image resolutions, with a fixed patch size and adaptively resized positional embeddings.
The sequence of patch tokens is reshaped into S × S × D feature maps, and the pooled value of each feature map serves as a prototype presence score. Contrastive learning, guided by alignment (L_A), tanh (L_T), and KoLeo (L_KoLeo) losses, clusters features that might represent the same biomarker together in the latent space.
The tanh loss prevents trivial solutions and feature collapse. Finally, a sparse linear layer connects the learned part-prototypes to classes, making the model's output interpretable as a scoring sheet.
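The pooling and tanh-loss steps above can be sketched as follows. This is a minimal illustration, not the repository's exact implementation; the softmax over the prototype dimension and the form of the tanh loss follow the PiP-Net family of models, and all function names here are hypothetical:

```python
import torch
import torch.nn.functional as F

def presence_scores(patch_tokens):
    """Reshape a sequence of N = S*S patch embeddings (B, N, D) into
    S x S x D feature maps and max-pool each map to a presence score."""
    B, N, D = patch_tokens.shape
    S = int(N ** 0.5)
    maps = patch_tokens.transpose(1, 2).reshape(B, D, S, S)  # (B, D, S, S)
    # softmax over the prototype dimension: each patch "votes" for one prototype
    maps = F.softmax(maps, dim=1)
    scores = F.adaptive_max_pool2d(maps, 1).flatten(1)       # (B, D)
    return maps, scores

def tanh_loss(scores, eps=1e-8):
    """Encourage every prototype to be present in at least one image of the
    batch, preventing the trivial all-zero solution (feature collapse)."""
    batch_presence = torch.tanh(scores.sum(dim=0))  # max activation per prototype
    return -torch.log(batch_presence + eps).mean()
```

The max-pooled scores feed the sparse linear layer, so a prediction can be read off as a sum of prototype activations times class weights, i.e., a scoring sheet.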
- Python (3.8.5)
- CUDA (11.1)
- PyTorch (1.13)
- torchvision (0.14)
- timm (0.5.4)
- Pillow (9.1.1)
- OpenCV (4.6.0)
- numpy (1.21.6)
The datasets used in the paper are publicly available. The code can be run on any dataset with image-level labels. The dataset should be organized in the following format:
dataset
│ test
│ │ class1
│ │ │ image1.png
│ │ │ image2.png
│ │ class2
│ │ │ image1.png
│ │ │ image2.png
│ train
│ │ class1
│ │ │ image1.png
│ │ │ image2.png
│ │ class2
│ │ │ image1.png
│ │ │ image2.png
│ val
│ │ class1
│ │ │ image1.png
│ │ │ image2.png
│ │ class2
│ │ │ image1.png
│ │ │ image2.png
The code for creating datasets is available in data_utils.py. The splits for the datasets used in the paper are available in the annotations folder.
PiPViT can be pretrained on any dataset; in the paper, we pretrained on the same training data used for fine-tuning. The pretraining function is pretrain in the train.py file. Pretraining can be done in two ways:
# single-scale pretraining that only benefits from contrastive learning
python Smain.py --config_path /base_path/PiPViTV2/config/Pretrain/Pretrain_224/OCTDrusen/Sconfig_patch16_224.yaml
# multi-scale pretraining that benefits from both contrastive learning and multi-resolution learning
python Smain_multi_scale_pretrain.py --config_path /base_path/PiPViTV2/config/Pretrain/Pretrain_Multi/OCTDrusen/Sconfig.yaml

The variables required for training are defined in the configuration file. The configuration files for pretraining are available in the config folder.
The resolutions used for the multi-resolution pretraining are defined under variable img_resolutions in Smain_multi_scale_pretrain.py.
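To pretrain at multiple resolutions with a fixed patch size, the ViT's positional embedding must be resized to match each new patch grid. The adaptive resizing mentioned above likely resembles the standard bicubic interpolation used by timm and DINOv2; this is a hedged sketch, not the repository's exact code, and the function name is hypothetical:

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_grid, old_grid):
    """Adapt a ViT positional embedding of shape (1, 1 + old_grid**2, D)
    to a new patch grid by bicubic interpolation of the spatial part,
    keeping the [CLS] token embedding unchanged."""
    cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]
    D = grid.shape[-1]
    # (1, N, D) -> (1, D, old_grid, old_grid) for 2D interpolation
    grid = grid.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_grid, new_grid),
                         mode="bicubic", align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([cls_tok, grid], dim=1)
```

For example, going from 224x224 input (14x14 grid at patch size 16) to 384x384 input (24x24 grid) enlarges the embedding from 197 to 577 tokens.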
PiPViT can be fine-tuned on any dataset. The function for fine-tuning PiPViT is train_val_all_losses in the train.py file. The code can be run using the following command:
python Smain.py --config_path /base_path/PiPViTV2/config/Train/train/OCTDrusen/Sconfig.yaml

The variables required for training are defined in the configuration file. The configuration file for fine-tuning is available in the config folder.
The code for interpreting the results is available in Smain_vis.py and Smain_vis_Squares.py. It can be run using the following command:
python Smain_vis.py --config_path /base_path/PiPViTV2/config/Interpretation/OCTDrusen/Sconfig.yaml

The code for drusen detection is available in Drusen_Prototype_Eval.py. It can be run using the following command:
python Drusen_Prototype_Eval.py --config_path /base_path/PiPViTV2/config/Vis/OCT5K/Sconfig.yaml

Note: The data splits for the OCTDrusen dataset are available in the annotations folder, along with the list of samples used for the drusen detection task on the OCT5K dataset (all_bounding_boxes_drusen.csv).
More details on the dataset can be found in the paper.
If you find this code useful, please consider citing the following paper:
@article{pipvitoghbaie2025,
title={{PiPViT}: Patch-based Visual Interpretable Prototypes for Retinal Image Analysis},
author={Oghbaie, Marzieh and Araujo, Teresa and Schmidt-Erfurth, Ursula and Bogunovic, Hrvoje},
journal={Biomedical Signal Processing and Control},
year={2025}
}
PiPViT code is mainly based on PiP-Net, along with timm, and DINOv2. We thank the authors for making their code available.
For baseline models, we used the following repositories: