Changes from all commits (61 commits)
53ea81e
added initial support for multi-dataset ood eval
fira7s Apr 9, 2025
9bae3eb
:sparkles: added scale ood eval method
fira7s Apr 14, 2025
3a069e9
:hammer: fixed ood table output and defined near and far ood
fira7s Apr 15, 2025
bf9c384
:hammer: restructure + added eval ood for imagenet
fira7s Apr 16, 2025
6ac1fff
Restructure + added various ood eval methods
fira7s Apr 23, 2025
1fb58d8
Changed ood splits
fira7s Apr 30, 2025
7f1d876
:hammer: Code restructure and fixes
fira7s May 7, 2025
30e05b4
:hammer: fixed missing dataset
fira7s May 7, 2025
ad05e35
:hammer: minor fixes
fira7s May 7, 2025
be802f0
:hammer: minor fixes
fira7s May 7, 2025
d994edc
:bug: Fix ImageNet Variation download
o-laurent May 14, 2025
bd80851
:hammer: fixed tests and ood eval
fira7s May 14, 2025
f9ab886
Merge remote-tracking branch 'origin/OodEval' into OodEval
fira7s May 14, 2025
85a89e9
Merge branch 'dev' into OodEval
fira7s May 14, 2025
f695223
:hammer: minor fixes
fira7s May 15, 2025
36211da
:hammer: fixed tmpscale/confromal for ensemble methods
fira7s May 16, 2025
622bb11
:hammer: fixed issues with postprocess+ensembling
fira7s May 19, 2025
8faf5a0
minor fixes
fira7s May 19, 2025
d4c0a00
:hammer: fixing tests
fira7s May 19, 2025
4607b0f
:hammer: minor fixes
fira7s May 23, 2025
8300f3b
:hammer: minor fixes
fira7s May 23, 2025
e26a274
Merge
fira7s May 23, 2025
c60b3c2
minor fixes
fira7s May 23, 2025
51be518
:hammer: ood eval fixes
fira7s Jun 2, 2025
bec0916
:hammer: fixed dependecies
fira7s Jun 4, 2025
2f219f5
:hammer: fixed ood splits
fira7s Jun 11, 2025
8c5d53d
:hammer: minor fixes
fira7s Jun 11, 2025
59a0240
:hammer: fix tests
fira7s Jun 18, 2025
fc449af
added vit code and tutorial
fira7s Jul 31, 2025
99c508a
added sst2 datamodule and nlp tutorial
fira7s Aug 5, 2025
3a99a1c
Merge branch 'dev' into OodEval
fira7s Aug 5, 2025
9d5b1f1
minor fixes
fira7s Aug 5, 2025
c67e951
Merge branch 'OodEval' of https://github.com/ENSTA-U2IS-AI/torch-unce…
fira7s Aug 5, 2025
d366ed0
fixed typo
fira7s Aug 5, 2025
1df5495
fixed dependecies
fira7s Aug 5, 2025
a57a62d
fixed dependecies
fira7s Aug 5, 2025
3385acc
fix formatting issue
fira7s Aug 18, 2025
35bc873
fixed imagenet test
fira7s Aug 18, 2025
fc3d807
tests fix attempt
fira7s Aug 18, 2025
e405cac
moved config files to hf
fira7s Aug 18, 2025
3b1035b
added more datamodule tests
fira7s Aug 21, 2025
b4c1b65
added more datamodules tests
fira7s Aug 21, 2025
ab659d0
added ood tests
fira7s Aug 26, 2025
37e2305
Merge branch 'dev' into OodEval
fira7s Aug 26, 2025
41b620d
updated ood tutorial
fira7s Aug 27, 2025
ee1d5ee
Merge branch 'OodEval' of https://github.com/ENSTA-U2IS-AI/torch-unce…
fira7s Aug 27, 2025
6f01210
minor fixes
fira7s Aug 27, 2025
d534ad3
minor fixes
fira7s Aug 28, 2025
77d9e5e
attempt to fix coverage
fira7s Aug 28, 2025
c235e9b
attempt to fix coverage
fira7s Sep 1, 2025
32121a2
attempt to fix coverage
fira7s Sep 1, 2025
61ae40f
attempt to fix coverage
fira7s Sep 1, 2025
d4e5664
attempt to fix coverage
fira7s Sep 1, 2025
74feb4b
attempt to fix coverage
fira7s Sep 1, 2025
619b365
attempt to fix coverage
fira7s Sep 12, 2025
db3f66b
attempt to fix coverage
fira7s Sep 12, 2025
0577694
attempt to fix coverage
fira7s Sep 12, 2025
cba5666
attempt to fix coverage
fira7s Sep 12, 2025
9a5d68e
attempt to fix coverage
fira7s Oct 10, 2025
036bda0
fixed far ood logging
fira7s Oct 10, 2025
a4d3715
fixed shift datasets for IM1K
fira7s Oct 21, 2025
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -24,6 +24,7 @@ repos:
        name: check-added-large-files
        entry: check-added-large-files
        language: system
        args: ['--maxkb=2048']
      - id: ruff-check
        name: ruff-check
        entry: ruff check
224 changes: 224 additions & 0 deletions auto_tutorial_source/Bayesian_Methods/tuto_ood.py
@@ -0,0 +1,224 @@
"""
Simple OOD Evaluation
================================================


In this tutorial, we’ll demonstrate how to perform out-of-distribution (OOD) evaluation using TorchUncertainty’s datamodules and routines. You’ll learn to:

1. **Set up a CIFAR-100 datamodule** that automatically handles in-distribution, near-OOD, and far-OOD splits.
2. **Run the `ClassificationRoutine`** to compute both in-distribution accuracy and OOD metrics (AUROC, AUPR, FPR95).
3. **Plug in your own OOD datasets** for fully custom evaluation.

Foreword on Out-of-Distribution Detection
-----------------------------------------

Out-of-Distribution (OOD) detection measures a model’s ability to recognize inputs that differ from its training distribution. TorchUncertainty integrates common OOD metrics directly into the Lightning test loop, including:

- **AUROC** (Area Under the ROC Curve)
- **AUPR** (Area Under the Precision-Recall Curve)
- **FPR95** (False Positive Rate at 95% True Positive Rate)

With just a few lines of code you can compare in-distribution performance to OOD detection performance under both “near” and “far” shifts. By default, TorchUncertainty uses the
popular OpenOOD library to define the near- and far-OOD datasets and splits. You can also use your own datasets by passing them to the datamodule.
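
For intuition, here is a minimal hand-rolled sketch of how these metrics can be
computed from per-sample OOD scores (higher meaning "more OOD"). This is not
TorchUncertainty's internal implementation, and the score arrays are dummy data:

.. code-block:: python

    import numpy as np
    from sklearn.metrics import average_precision_score, roc_auc_score

    scores_id = np.random.randn(1000)         # hypothetical scores on ID data
    scores_ood = np.random.randn(1000) + 2.0  # hypothetical scores on OOD data

    labels = np.concatenate([np.zeros_like(scores_id), np.ones_like(scores_ood)])
    scores = np.concatenate([scores_id, scores_ood])

    auroc = roc_auc_score(labels, scores)
    aupr = average_precision_score(labels, scores)

    # FPR95: fraction of ID samples wrongly flagged as OOD at the threshold
    # that still detects 95% of the OOD samples.
    threshold = np.quantile(scores_ood, 0.05)
    fpr95 = (scores_id >= threshold).mean()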

Supported Datamodules and Default OOD Splits
--------------------------------------------

.. list-table:: Datamodules & Default OOD Splits
   :header-rows: 1
   :widths: 20 15 20 20

   * - **Datamodule**
     - **In-Domain**
     - **Default Near-OOD (Hard)**
     - **Default Far-OOD (Easy)**
   * - ``CIFAR10DataModule``
     - CIFAR-10
     - CIFAR-100, Tiny ImageNet
     - MNIST, SVHN, Textures, Places365
   * - ``CIFAR100DataModule``
     - CIFAR-100
     - CIFAR-10, Tiny ImageNet
     - MNIST, SVHN, Textures, Places365
   * - ``ImageNetDataModule``
     - ImageNet-1K
     - SSB-hard, NINCO
     - iNaturalist, Textures, OpenImage-O
   * - ``ImageNet200DataModule``
     - ImageNet-200
     - SSB-hard, NINCO
     - iNaturalist, Textures, OpenImage-O

Supported OOD Criteria
----------------------

.. list-table:: Supported OOD Criteria
   :header-rows: 1
   :widths: 15 50

   * - **Criterion**
     - **Original Reference (Year, Venue)**
   * - ``msp``
     - Hendrycks & Gimpel, A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks `ICLR 2017 <https://arxiv.org/abs/1610.02136>`_.
   * - ``maxlogit``
     - /
   * - ``energy``
     - Liu et al., Energy-based Out-of-Distribution Detection `NeurIPS 2020 <https://arxiv.org/abs/2010.03759>`_.
   * - ``odin``
     - Liang, Li & Srikant, Enhancing The Reliability of Out-of-Distribution Image Detection in Neural Networks `ICLR 2018 <https://arxiv.org/abs/1706.02690>`_.
   * - ``entropy``
     - /
   * - ``mutual_information``
     - /
   * - ``variation_ratio``
     - /
   * - ``scale``
     - Hendrycks et al., Scaling Out-of-Distribution Detection for Real-World Settings `ICML 2022 <https://proceedings.mlr.press/v162/hendrycks22a/hendrycks22a.pdf>`_.
   * - ``ash``
     - Djurisic et al., ASH: Extremely Simple Activation Shaping for OOD Detection `ICLR 2023 <https://arxiv.org/pdf/2209.09858>`_.
   * - ``react``
     - Sun et al., ReAct: Out-of-distribution Detection with Rectified Activations `NeurIPS 2021 <https://proceedings.neurips.cc/paper/2021/file/01894d6f048493d2cacde3c579c315a3-Paper.pdf>`_.
   * - ``adascale_a``
     - Regmi et al., AdaSCALE: Adaptive Scaling for OOD Detection (`arXiv 2025 <https://arxiv.org/pdf/2503.08023>`_).
   * - ``vim``
     - Wang et al., ViM: Out-of-Distribution with Virtual-Logit Matching `CVPR 2022 <https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_ViM_Out-of-Distribution_With_Virtual-Logit_Matching_CVPR_2022_paper.pdf>`_.
   * - ``knn``
     - Sun et al., Out-of-Distribution Detection with Deep Nearest Neighbors `ICML 2022 <https://arxiv.org/abs/2106.01477>`_.
   * - ``gen``
     - Liu et al., GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection `CVPR 2023 <https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.pdf>`_.
   * - ``nnguide``
     - Park et al., Nearest Neighbor Guidance for Out-of-Distribution Detection `ICCV 2023 <https://openaccess.thecvf.com/content/ICCV2023/papers/Park_Nearest_Neighbor_Guidance_for_Out-of-Distribution_Detection_ICCV_2023_paper.pdf>`_.

.. note::

   - All of these criteria can be passed as the ``ood_criterion`` argument to
     ``ClassificationRoutine``.
   - Ensemble-based criteria such as ``mutual_information`` and ``variation_ratio``
     require multiple stochastic forward passes (e.g., a deep ensemble or MC dropout).
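
For example, the simplest criteria are direct functions of the network's logits.
A hand-rolled sketch of three of them on dummy logits (sign conventions may
differ from TorchUncertainty's internals):

.. code-block:: python

    import torch

    logits = torch.randn(8, 100)  # hypothetical batch of logits

    msp = -logits.softmax(dim=-1).max(dim=-1).values  # low max probability => OOD
    maxlogit = -logits.max(dim=-1).values             # low max logit => OOD
    energy = -torch.logsumexp(logits, dim=-1)         # high energy => OOD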



.. note::

   - **Near-OOD** splits are semantically similar to the in-domain data.
   - **Far-OOD** splits come from more distant distributions (e.g., ImageNet variants).
   - Override the defaults by passing your own ``near_ood_datasets`` / ``far_ood_datasets``
     (see Section 5 below).


1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~

To evaluate OOD detection with TorchUncertainty, we have to load the following:

- the model: a ResNet-18 (CIFAR-style, for 32x32 inputs) trained on the in-distribution CIFAR-100 data
- the classification routine from torch_uncertainty.routines
- the datamodule that handles the dataloaders: CIFAR100DataModule from torch_uncertainty.datamodules
"""

# %%
from pathlib import Path

# %%
# 2. Load the trained model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# In this tutorial we load a pretrained model, but you can also train your own using the same classification routine and still get OOD-related metrics at test time.


import torch
from torch_uncertainty.models.resnet import resnet
from huggingface_hub import hf_hub_download

net = resnet(in_channels=3, arch=18, num_classes=100, style="cifar", conv_bias=False)

# download the pretrained checkpoint from the Hugging Face Hub and load the weights
path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c100", filename="resnet18_c100.ckpt")
net.load_state_dict(torch.load(path))

net.cuda()
net.eval()
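
# The two lines above assume a CUDA GPU; a device-agnostic alternative (a sketch,
# equivalent in effect, not required by the rest of this tutorial) would be:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   net.to(device).eval()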


# %%
# 3. Defining the necessary datamodules
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In the following, we instantiate the trainer and define the root folder for the datasets and the logs.
# We also create the datamodule that handles the CIFAR-100 dataset, dataloaders, and transforms.
# Datamodules can also handle OOD detection by setting the ``eval_ood`` parameter to True.

from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.routines import ClassificationRoutine
import torch.nn as nn
from torch_uncertainty import TUTrainer


root = Path("data1")
datamodule = CIFAR100DataModule(root=root, batch_size=200, eval_ood=True, eval_shift=True)
trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True)


# %%
# 4. Define the classification routine and launch the test
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Define the classification routine for evaluation. We use the CrossEntropyLoss
# as the loss function since we are working on a classification task.
# The routine is configured to handle OOD detection and distributional shifts using the specified model, loss function, and evaluation criteria.

routine = ClassificationRoutine(
    num_classes=datamodule.num_classes,
    eval_ood=True,
    model=net,
    loss=nn.CrossEntropyLoss(),
    eval_shift=True,
    ood_criterion="ash",
)

# Perform testing using the defined routine and datamodule.
results = trainer.test(model=routine, datamodule=datamodule)


# %%
# Here, we show the various test metrics along with the OOD evaluation metrics (AUROC, AUPR,
# and FPR95) on the near- and far-OOD datasets, defined by default according to the OpenOOD
# splits (see the reference at the end of this tutorial).
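
# ``trainer.test`` returns a list with one dict of metric names to values, so the
# results can also be inspected programmatically (the exact metric keys depend on
# the routine's configuration):
for name, value in results[0].items():
    print(f"{name}: {value:.4f}")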


# %%
# 5. Defining custom ood datasets
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# If you don't want to use the OpenOOD datasets or dataset splits, you can pass your own
# datasets as lists to the ``near_ood_datasets`` or ``far_ood_datasets`` datamodule arguments
# and use them for OOD evaluation. Make sure they inherit from the
# Dataset class from torch.utils.data; below is an example of such a case.
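
# For reference, here is a minimal sketch of such a custom dataset; the class name
# and the in-memory tensors are hypothetical, and any ``torch.utils.data.Dataset``
# subclass returning (input, target) pairs works:
from torch.utils.data import Dataset


class MyOODDataset(Dataset):
    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]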

from torchvision.datasets import CIFAR10, MNIST
from torchvision.transforms import v2


test_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize(32),
        v2.CenterCrop(32),
        v2.ToDtype(dtype=torch.float32, scale=True),
        # CIFAR-100 channel-wise mean and standard deviation
        v2.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2673, 0.2564, 0.2762)),
    ]
)

custom_dataset1 = CIFAR10(root=root, train=False, download=True, transform=test_transform)
custom_dataset2 = MNIST(root=root, train=False, download=True, transform=test_transform)

datamodule = CIFAR100DataModule(
    root=root,
    batch_size=200,
    eval_ood=True,
    eval_shift=True,
    near_ood_datasets=[custom_dataset1],
    far_ood_datasets=[custom_dataset2],
)

# Perform testing using the custom OOD datasets defined above.
results = trainer.test(model=routine, datamodule=datamodule)


# %%
# References
# ----------
# - **OpenOOD:** Jingyang Zhang et al. (`arXiv 2023 <https://arxiv.org/pdf/2306.09301>`_). OpenOOD v1.5: Enhanced Benchmark for Out-of-Distribution Detection.
153 changes: 153 additions & 0 deletions auto_tutorial_source/Classification/tutorial_bert.py
@@ -0,0 +1,153 @@
"""
Benchmark BERT with torch-uncertainty on SST-2
===============================================

This tutorial shows how to use torch-uncertainty to benchmark a BERT model on the SST-2 dataset with various robustness metrics,
and how to easily apply a post-processing step (MC dropout) on top of either the single model or a deep ensemble.

Dataset
-------

In this tutorial we use the SST-2 dataset, available directly through torch-uncertainty, along with various near-/far-OOD
datasets that are also handled automatically by torch-uncertainty.


1. Define and load the single BERT model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""

# %%
import torch
import torch.nn as nn
from collections import OrderedDict
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def load_tu_ckpt_into_hf(
    backbone, repo_id: str, filename: str, strict: bool = True, map_location="cpu"
):
    """Download a TorchUncertainty checkpoint from the Hub and load it into a HF backbone."""
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)

    sd = torch.load(ckpt_path, map_location=map_location)
    sd = sd.get("state_dict", sd)

    # strip common Lightning/TorchUncertainty prefixes from the state-dict keys
    def with_prefix(prefix):
        return OrderedDict((k[len(prefix) :], v) for k, v in sd.items() if k.startswith(prefix))

    for pref in ("model.backbone.", "model.", "backbone."):
        sub = with_prefix(pref)
        if sub:
            return backbone.load_state_dict(sub, strict=strict)

    return backbone.load_state_dict(sd, strict=strict)


class HFClassifier(nn.Module):
    """Thin wrapper returning raw logits, so the model composes with TorchUncertainty routines."""

    def __init__(self, model_name: str, num_labels: int = 2, local_files_only: bool = False):
        super().__init__()
        self.backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, local_files_only=local_files_only
        )

    def forward(self, *args, **kwargs):
        # accept either a single dict of tokenizer outputs or keyword arguments
        inputs = args[0] if (len(args) == 1 and isinstance(args[0], dict)) else kwargs
        return self.backbone(**inputs).logits


net1 = HFClassifier("bert-base-uncased", num_labels=2)

load_tu_ckpt_into_hf(
net1.backbone,
repo_id="torch-uncertainty/bert-sst2",
filename="model1.ckpt",
)
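
# %%
# As a quick sanity check (optional, since the datamodule below handles tokenization),
# we can tokenize a sentence ourselves and pass it through the wrapper:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["a touching and funny movie"], return_tensors="pt")
with torch.no_grad():
    logits = net1(dict(batch))
print(logits.shape)  # torch.Size([1, 2]): one sentence, two sentiment classes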


# %%
# 2. Benchmark the single model
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We first define the SST-2 datamodule, then run the classification routine as follows.

from torch_uncertainty.routines import ClassificationRoutine
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import Sst2DataModule


dm = Sst2DataModule(
    batch_size=64,
    eval_ood=True,
)

trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True, devices=1)

routine = ClassificationRoutine(
    num_classes=2,
    model=net1,
    loss=nn.CrossEntropyLoss(),
    eval_ood=True,
)

res = trainer.test(routine, datamodule=dm)


# %%
# 3. Apply a post-processing step on top of the single model
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# As an example, we apply Monte Carlo dropout on top of the single model,
# but torch-uncertainty supports many other post-processing methods such as temperature scaling and conformal prediction; please refer to the documentation.
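#
# For intuition, temperature scaling simply divides the logits by a scalar
# temperature fitted on held-out data. A minimal hand-rolled sketch (not the
# library's API; ``val_logits`` and ``val_labels`` are hypothetical tensors):
#
#     temperature = nn.Parameter(torch.ones(1))
#     optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=50)
#
#     def closure():
#         optimizer.zero_grad()
#         loss = nn.functional.cross_entropy(val_logits / temperature, val_labels)
#         loss.backward()
#         return loss
#
#     optimizer.step(closure)
#     calibrated_probs = (val_logits / temperature).softmax(dim=-1)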

from torch_uncertainty.models import mc_dropout

mc_net = mc_dropout(
    model=net1,
    num_estimators=8,
    on_batch=False,
)

routine = ClassificationRoutine(
    num_classes=2,
    model=mc_net,
    loss=nn.CrossEntropyLoss(),
    eval_ood=True,
)

res = trainer.test(routine, datamodule=dm)


# %%
# 4. Load and benchmark a deep ensemble of BERT models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let us load the remaining models of the deep ensemble and then benchmark them easily with torch-uncertainty.

net2 = HFClassifier("bert-base-uncased", num_labels=2)

load_tu_ckpt_into_hf(
net2.backbone,
repo_id="torch-uncertainty/bert-sst2",
filename="model2.ckpt",
)


net3 = HFClassifier("bert-base-uncased", num_labels=2)

load_tu_ckpt_into_hf(
net3.backbone,
repo_id="torch-uncertainty/bert-sst2",
filename="model3.ckpt",
)


from torch_uncertainty.models import deep_ensembles

deep = deep_ensembles([net1, net2, net3])  # wrap the three members into a single ensemble model


routine = ClassificationRoutine(
    num_classes=2,
    model=deep,
    loss=nn.CrossEntropyLoss(),
    eval_ood=True,
)
res = trainer.test(routine, datamodule=dm)