Skip to content
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added support for `dinov2` feature extractor in `FrechetInceptionDistance` ([#3186](https://github.com/Lightning-AI/torchmetrics/pull/3186))


### Changed
Expand Down
170 changes: 146 additions & 24 deletions src/torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from torch.nn.functional import adaptive_avg_pool2d

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.utilities.prints import rank_zero_warn

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["FrechetInceptionDistance.plot"]

if _TORCH_FIDELITY_AVAILABLE:
from torch_fidelity.feature_extractor_dinov2 import FeatureExtractorDinoV2 as _FeatureExtractorDinoV2
from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3
from torch_fidelity.helpers import vassert
from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
Expand All @@ -36,6 +38,9 @@
class _FeatureExtractorInceptionV3(Module): # type: ignore[no-redef]
pass

class _FeatureExtractorDinoV2(Module): # type: ignore[no-redef]
pass

vassert = None
interpolate_bilinear_2d_like_tensorflow1x = None

Expand Down Expand Up @@ -171,6 +176,92 @@ def forward(self, x: Tensor) -> Tensor:
return out[0].reshape(x.shape[0], -1)


class NoTrainDinoV2(_FeatureExtractorDinoV2):
"""Module that never leaves evaluation mode."""

def __init__(
self,
name: str,
features_list: list[str],
feature_extractor_weights_path: Optional[str] = None,
antialias: bool = True,
) -> None:
if not _TORCH_FIDELITY_AVAILABLE:
raise ModuleNotFoundError(
"NoTrainDinoV2 module requires that `Torch-fidelity` is installed."
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
)
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
"NoTrainDinoV2 module requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[image]`."
)

super().__init__(name, features_list, feature_extractor_weights_path)
self.use_antialias = antialias
# put into evaluation mode
self.eval()

def train(self, mode: bool) -> "NoTrainDinoV2":
"""Force network to always be in evaluation mode."""
return super().train(False)

def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]:
"""Forward method of dinov2 net.

Copy of the forward method from this file:
https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_dinov2.py
with a single line change regarding the casting of `x` in the beginning.

Corresponding license file (Apache License, Version 2.0):
https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md

"""
vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8")
vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}")

x = x.to(self.feature_extractor_internal_dtype)
# N x 3 x ? x ?

if self.use_antialias:
x = torch.nn.functional.interpolate(
x,
size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
mode="bilinear",
align_corners=False,
antialias=True,
)
else:
x = interpolate_bilinear_2d_like_tensorflow1x(
x,
size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
align_corners=False,
)
# N x 3 x 224 x 224
from torchvision.transforms.functional import normalize

x = normalize(
x,
(255 * 0.485, 255 * 0.456, 255 * 0.406),
(255 * 0.229, 255 * 0.224, 255 * 0.225),
inplace=False,
)
# N x 3 x 224 x 224

x = self.model(x)

out = {
"dinov2": x.to(torch.float32),
}

return tuple(out[a] for a in self.features_list)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass of neural network with reshaping of output."""
out = self._torch_fidelity_forward(x)
return out[0].reshape(x.shape[0], -1)


def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor:
r"""Compute adjusted version of `Fid Score`_.

Expand All @@ -194,6 +285,19 @@ def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Te
return a + b - 2 * c


# Map feature strings to valid configurations
FEATURE_MAP = {
"inception-64": (NoTrainInceptionV3, "64", "64", (3, 299, 299)),
"inception-192": (NoTrainInceptionV3, "192", "192", (3, 299, 299)),
"inception-768": (NoTrainInceptionV3, "768", "768", (3, 299, 299)),
"inception-2048": (NoTrainInceptionV3, "2048", "2048", (3, 299, 299)),
"dino-384": (NoTrainDinoV2, "dinov2-vit-s-14", "dinov2", (3, 224, 224)),
"dino-768": (NoTrainDinoV2, "dinov2-vit-b-14", "dinov2", (3, 224, 224)),
"dino-1024": (NoTrainDinoV2, "dinov2-vit-l-14", "dinov2", (3, 224, 224)),
"dino-1536": (NoTrainDinoV2, "dinov2-vit-g-14", "dinov2", (3, 224, 224)),
}


class FrechetInceptionDistance(Metric):
r"""Calculate Fréchet inception distance (FID_) which is used to assess the quality of generated images.

Expand Down Expand Up @@ -308,12 +412,12 @@ class FrechetInceptionDistance(Metric):
fake_features_cov_sum: Tensor
fake_features_num_samples: Tensor

inception: Module
feature_network: str = "inception"
feature_extractor: Module
feature_network: str = "feature_extractor"

def __init__(
self,
feature: Union[int, Module] = 2048,
feature: Union[int, str, Module] = "inception-2048",
reset_real_features: bool = True,
normalize: bool = False,
input_img_size: tuple[int, int, int] = (3, 299, 299),
Expand All @@ -329,42 +433,60 @@ def __init__(
self.used_custom_model = False
antialias = antialias

if isinstance(feature, int):
num_features = feature
if isinstance(feature, str):
if feature not in FEATURE_MAP:
raise ValueError(
f"String input to argument `feature` must be one of {list(FEATURE_MAP.keys())}, but got {feature}."
)
feature_extractor_cls, name, feature_layer, default_img_size = FEATURE_MAP[feature]
input_img_size = default_img_size

if not _TORCH_FIDELITY_AVAILABLE:
raise ModuleNotFoundError(
"FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
)
valid_int_input = (64, 192, 768, 2048)
if feature not in valid_int_input:

self.feature_extractor = feature_extractor_cls(
name=name,
features_list=[feature_layer],
feature_extractor_weights_path=feature_extractor_weights_path,
antialias=antialias,
)
num_features = int(feature.split("-")[-1])
elif isinstance(feature, int):
rank_zero_warn(
"Using an integer input to `feature` is deprecated and will be removed in v1.9."
"Instead, use a string input like 'inception-2048' or 'dino-768'.",
DeprecationWarning,
)
if feature not in [64, 192, 768, 2048]:
raise ValueError(
f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
f"Integer input to argument `feature` must be one of [64, 192, 768, 2048], but got {feature}."
)

self.inception = NoTrainInceptionV3(
name="inception-v3-compat",
self.feature_extractor = NoTrainInceptionV3(
name=f"inception-{feature}",
features_list=[str(feature)],
feature_extractor_weights_path=feature_extractor_weights_path,
antialias=antialias,
)

num_features = feature
elif isinstance(feature, Module):
self.inception = feature
self.feature_extractor = feature
self.used_custom_model = True
if hasattr(self.inception, "num_features"):
if isinstance(self.inception.num_features, int):
num_features = self.inception.num_features
elif isinstance(self.inception.num_features, Tensor):
num_features = int(self.inception.num_features.item())
if hasattr(self.feature_extractor, "num_features"):
if isinstance(self.feature_extractor.num_features, int):
num_features = self.feature_extractor.num_features
elif isinstance(self.feature_extractor.num_features, Tensor):
num_features = int(self.feature_extractor.num_features.item())
else:
raise TypeError("Expected `self.inception.num_features` to be of type int or Tensor.")
raise TypeError("Expected `self.feature_extractor.num_features` to be of type int or Tensor.")
else:
if self.normalize:
dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32)
else:
dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8)
num_features = self.inception(dummy_image).shape[-1]
num_features = self.feature_extractor(dummy_image).shape[-1]
else:
raise TypeError("Got unknown input to argument `feature`")

Expand All @@ -391,7 +513,7 @@ def update(self, imgs: Tensor, real: bool) -> None:

"""
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
features = self.inception(imgs)
features = self.feature_extractor(imgs)
self.orig_dtype = features.dtype
features = features.double()

Expand Down Expand Up @@ -440,8 +562,8 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":

"""
out = super().set_dtype(dst_type)
if isinstance(out.inception, NoTrainInceptionV3):
out.inception._dtype = dst_type
if isinstance(out.feature_extractor, (NoTrainInceptionV3, NoTrainDinoV2)):
out.feature_extractor._dtype = dst_type
return out

def plot(
Expand Down
Loading
Loading