diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ccf8cae705..d2e13e7085f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for `dinov2` feature extractor in `FrechetInceptionDistance` ([#3186](https://github.com/Lightning-AI/torchmetrics/pull/3186)) + + - diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 49f9f057d83..33829c08506 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -21,13 +21,15 @@ from torch.nn.functional import adaptive_avg_pool2d from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +from torchmetrics.utilities.prints import rank_zero_warn if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["FrechetInceptionDistance.plot"] if _TORCH_FIDELITY_AVAILABLE: + from torch_fidelity.feature_extractor_dinov2 import FeatureExtractorDinoV2 as _FeatureExtractorDinoV2 from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3 from torch_fidelity.helpers import vassert from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x @@ -36,6 +38,9 @@ class _FeatureExtractorInceptionV3(Module): # type: ignore[no-redef] pass + class _FeatureExtractorDinoV2(Module): # type: ignore[no-redef] + pass + vassert = None interpolate_bilinear_2d_like_tensorflow1x = None @@ -171,6 +176,92 @@ def forward(self, x: Tensor) -> Tensor: return out[0].reshape(x.shape[0], -1) +class NoTrainDinoV2(_FeatureExtractorDinoV2): + """Module that never leaves evaluation mode.""" + + def __init__( + self, + name: str, + features_list: list[str], + 
feature_extractor_weights_path: Optional[str] = None, + antialias: bool = True, + ) -> None: + if not _TORCH_FIDELITY_AVAILABLE: + raise ModuleNotFoundError( + "NoTrainDinoV2 module requires that `Torch-fidelity` is installed." + " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." + ) + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( + "NoTrainDinoV2 module requires that `torchvision` is installed." + " Please install with `pip install torchmetrics[image]`." + ) + + super().__init__(name, features_list, feature_extractor_weights_path) + self.use_antialias = antialias + # put into evaluation mode + self.eval() + + def train(self, mode: bool) -> "NoTrainDinoV2": + """Force network to always be in evaluation mode.""" + return super().train(False) + + def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]: + """Forward method of dinov2 net. + + Copy of the forward method from this file: + https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_dinov2.py + with a single line change regarding the casting of `x` in the beginning. + + Corresponding license file (Apache License, Version 2.0): + https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md + + """ + vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8") + vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}") + + x = x.to(self.feature_extractor_internal_dtype) + # N x 3 x ? x ? 
+ + if self.use_antialias: + x = torch.nn.functional.interpolate( + x, + size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE), + mode="bilinear", + align_corners=False, + antialias=True, + ) + else: + x = interpolate_bilinear_2d_like_tensorflow1x( + x, + size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE), + align_corners=False, + ) + # N x 3 x 224 x 224 + from torchvision.transforms.functional import normalize + + x = normalize( + x, + (255 * 0.485, 255 * 0.456, 255 * 0.406), + (255 * 0.229, 255 * 0.224, 255 * 0.225), + inplace=False, + ) + # N x 3 x 224 x 224 + + x = self.model(x) + + out = { + "dinov2": x.to(torch.float32), + } + + return tuple(out[a] for a in self.features_list) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of neural network with reshaping of output.""" + out = self._torch_fidelity_forward(x) + return out[0].reshape(x.shape[0], -1) + + def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor: r"""Compute adjusted version of `Fid Score`_. @@ -194,6 +285,19 @@ def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Te return a + b - 2 * c +# Map feature strings to valid configurations +FEATURE_MAP = { + "inception-64": (NoTrainInceptionV3, "64", "64", (3, 299, 299)), + "inception-192": (NoTrainInceptionV3, "192", "192", (3, 299, 299)), + "inception-768": (NoTrainInceptionV3, "768", "768", (3, 299, 299)), + "inception-2048": (NoTrainInceptionV3, "2048", "2048", (3, 299, 299)), + "dino-384": (NoTrainDinoV2, "dinov2-vit-s-14", "dinov2", (3, 224, 224)), + "dino-768": (NoTrainDinoV2, "dinov2-vit-b-14", "dinov2", (3, 224, 224)), + "dino-1024": (NoTrainDinoV2, "dinov2-vit-l-14", "dinov2", (3, 224, 224)), + "dino-1536": (NoTrainDinoV2, "dinov2-vit-g-14", "dinov2", (3, 224, 224)), +} + + class FrechetInceptionDistance(Metric): r"""Calculate Fréchet inception distance (FID_) which is used to assess the quality of generated images. 
@@ -308,12 +412,12 @@ class FrechetInceptionDistance(Metric): fake_features_cov_sum: Tensor fake_features_num_samples: Tensor - inception: Module - feature_network: str = "inception" + feature_extractor: Module + feature_network: str = "feature_extractor" def __init__( self, - feature: Union[int, Module] = 2048, + feature: Union[int, str, Module] = "inception-2048", reset_real_features: bool = True, normalize: bool = False, input_img_size: tuple[int, int, int] = (3, 299, 299), @@ -329,42 +433,60 @@ def __init__( self.used_custom_model = False antialias = antialias - if isinstance(feature, int): - num_features = feature + if isinstance(feature, str): + if feature not in FEATURE_MAP: + raise ValueError( + f"String input to argument `feature` must be one of {list(FEATURE_MAP.keys())}, but got {feature}." + ) + feature_extractor_cls, name, feature_layer, default_img_size = FEATURE_MAP[feature] + input_img_size = default_img_size + if not _TORCH_FIDELITY_AVAILABLE: raise ModuleNotFoundError( "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed." " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`." ) - valid_int_input = (64, 192, 768, 2048) - if feature not in valid_int_input: + + self.feature_extractor = feature_extractor_cls( + name=name, + features_list=[feature_layer], + feature_extractor_weights_path=feature_extractor_weights_path, + antialias=antialias, + ) + num_features = int(feature.split("-")[-1]) + elif isinstance(feature, int): + rank_zero_warn( + "Using an integer input to `feature` is deprecated and will be removed in v1.9." + " Instead, use a string input like 'inception-2048' or 'dino-768'.", + DeprecationWarning, + ) + if feature not in [64, 192, 768, 2048]: raise ValueError( - f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}." + f"Integer input to argument `feature` must be one of [64, 192, 768, 2048], but got {feature}." 
) - - self.inception = NoTrainInceptionV3( - name="inception-v3-compat", + self.feature_extractor = NoTrainInceptionV3( + name=f"inception-{feature}", features_list=[str(feature)], feature_extractor_weights_path=feature_extractor_weights_path, antialias=antialias, ) - + num_features = feature elif isinstance(feature, Module): - self.inception = feature + self.feature_extractor = feature self.used_custom_model = True - if hasattr(self.inception, "num_features"): - if isinstance(self.inception.num_features, int): - num_features = self.inception.num_features - elif isinstance(self.inception.num_features, Tensor): - num_features = int(self.inception.num_features.item()) + if hasattr(self.feature_extractor, "num_features"): + if isinstance(self.feature_extractor.num_features, int): + num_features = self.feature_extractor.num_features + elif isinstance(self.feature_extractor.num_features, Tensor): + num_features = int(self.feature_extractor.num_features.item()) else: - raise TypeError("Expected `self.inception.num_features` to be of type int or Tensor.") + raise TypeError("Expected `self.feature_extractor.num_features` to be of type int or Tensor.") else: if self.normalize: dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32) else: dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8) - num_features = self.inception(dummy_image).shape[-1] + num_features = self.feature_extractor(dummy_image).shape[-1] else: raise TypeError("Got unknown input to argument `feature`") @@ -391,7 +513,7 @@ def update(self, imgs: Tensor, real: bool) -> None: """ imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs - features = self.inception(imgs) + features = self.feature_extractor(imgs) self.orig_dtype = features.dtype features = features.double() @@ -440,8 +562,8 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric": """ out = super().set_dtype(dst_type) - if isinstance(out.inception, NoTrainInceptionV3): 
- out.inception._dtype = dst_type + if isinstance(out.feature_extractor, (NoTrainInceptionV3, NoTrainDinoV2)): + out.feature_extractor._dtype = dst_type return out def plot( diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index baec39f1197..3e37e9be784 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -20,14 +20,14 @@ from torch.nn import Module from torch.utils.data import Dataset -from torchmetrics.image.fid import FrechetInceptionDistance, NoTrainInceptionV3 +from torchmetrics.image.fid import FrechetInceptionDistance, NoTrainDinoV2, NoTrainInceptionV3 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE from unittests._helpers import seed_all seed_all(42) -def test_no_train_network_missing_torch_fidelity(monkeypatch): +def test_no_train_inception_network_missing_torch_fidelity(monkeypatch): """Assert that NoTrainInceptionV3 raises an error if torch-fidelity is not installed.""" # mock/fake the import of torch-fidelity monkeypatch.setattr("torchmetrics.image.fid._TORCH_FIDELITY_AVAILABLE", False) @@ -37,6 +37,14 @@ def test_no_train_network_missing_torch_fidelity(monkeypatch): NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"]) +def test_no_train_dinov2_network_missing_torch_fidelity(monkeypatch): + """Assert that NoTrainDinoV2 raises an error if torch-fidelity is not installed.""" + # mock/fake the import of torch-fidelity + monkeypatch.setattr("torchmetrics.image.fid._TORCH_FIDELITY_AVAILABLE", False) + with pytest.raises(ModuleNotFoundError, match="NoTrainDinoV2 module requires that `Torch-fidelity` is installed.*"): + NoTrainDinoV2(name="dinov2-vit-g-14", features_list=["1536"]) + + @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_no_train(): """Assert that metric never leaves evaluation mode.""" @@ -52,7 +60,9 @@ def forward(self, x): model = MyModel() model.train() assert model.training - assert 
not model.metric.inception.training, "FID metric was changed to training mode which should not happen" + assert not model.metric.feature_extractor.training, ( + "FID metric was changed to training mode which should not happen" + ) @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @@ -71,11 +81,14 @@ def test_fid_raises_errors_and_warnings(): if _TORCH_FIDELITY_AVAILABLE: with pytest.raises(ValueError, match="Integer input to argument `feature` must be one of .*"): _ = FrechetInceptionDistance(feature=2) + + with pytest.raises(ValueError, match="String input to argument `feature` must be one of .*"): + _ = FrechetInceptionDistance(feature="invalid-feature") else: with pytest.raises( ModuleNotFoundError, - match="FID metric requires that `Torch-fidelity` is installed." - " Either install as `pip install torchmetrics[image-quality]` or `pip install torch-fidelity`.", + match="FrechetInceptionDistance metric requires that `Torch-fidelity` is installed." 
+ " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`.", ): _ = FrechetInceptionDistance() @@ -95,14 +108,17 @@ def __call__(self, img) -> torch.Tensor: @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") -@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()]) +@pytest.mark.parametrize("feature", [64, 192, "inception-64", "dino-384", _DummyFeatureExtractor()]) def test_fid_same_input(feature): """If real and fake are update on the same data the fid score should be 0.""" metric = FrechetInceptionDistance(feature=feature) + # Determine correct input size based on feature + img_size = (10, 3, 224, 224) if isinstance(feature, str) and feature.startswith("dino") else (10, 3, 299, 299) + seed_all(42) for _ in range(2): - img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8) + img = torch.randint(0, 255, img_size, dtype=torch.uint8) metric.update(img, real=True) metric.update(img, real=False) @@ -128,18 +144,26 @@ def __len__(self) -> int: @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("equal_size", [False, True]) -def test_compare_fid(tmpdir, equal_size, feature=768): - """Check that the hole pipeline give the same result as torch-fidelity.""" +@pytest.mark.parametrize( + "feature_config", + [ + pytest.param(("inception-64", (3, 299, 299), "inception-v3-compat", 64), id="inception-64"), + pytest.param(("dino-384", (3, 224, 224), "dinov2-vit-s-14", "dinov2"), id="dino-384"), + ], +) +def test_compare_fid(tmpdir, equal_size, feature_config): + """Check that the whole pipeline gives the same result as torch-fidelity.""" from torch_fidelity import calculate_metrics + feature, img_size, fidelity_name, fidelity_feature_list = feature_config metric = FrechetInceptionDistance(feature=feature).cuda() n, m = 
100, 100 if equal_size else 90 # Generate some synthetic data seed_all(42) - img1 = torch.randint(0, 180, (n, 3, 299, 299), dtype=torch.uint8) - img2 = torch.randint(100, 255, (m, 3, 299, 299), dtype=torch.uint8) + img1 = torch.randint(0, 180, (n, *img_size), dtype=torch.uint8) + img2 = torch.randint(100, 255, (m, *img_size), dtype=torch.uint8) batch_size = 10 for i in range(n // batch_size): @@ -152,7 +176,8 @@ def test_compare_fid(tmpdir, equal_size, feature=768): input1=_ImgDataset(img1), input2=_ImgDataset(img2), fid=True, - feature_layer_fid=str(feature), + feature_extractor=fidelity_name, + feature_layer_fid=str(fidelity_feature_list), batch_size=batch_size, save_cpu_ram=True, ) @@ -163,20 +188,23 @@ def test_compare_fid(tmpdir, equal_size, feature=768): @pytest.mark.parametrize("reset_real_features", [True, False]) -def test_reset_real_features_arg(reset_real_features): +@pytest.mark.parametrize("feature_config", [("inception-64", (3, 299, 299)), ("dino-384", (3, 224, 224))]) +def test_reset_real_features_arg(reset_real_features, feature_config): """Test that `reset_real_features` argument works as expected.""" - metric = FrechetInceptionDistance(feature=64, reset_real_features=reset_real_features) + feature, img_size = feature_config + metric = FrechetInceptionDistance(feature=feature, reset_real_features=reset_real_features) + feature_dim = int(feature.split("-")[-1]) - metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True) - metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False) + metric.update(torch.randint(0, 180, (2, *img_size), dtype=torch.uint8), real=True) + metric.update(torch.randint(0, 180, (2, *img_size), dtype=torch.uint8), real=False) assert metric.real_features_num_samples == 2 - assert metric.real_features_sum.shape == torch.Size([64]) - assert metric.real_features_cov_sum.shape == torch.Size([64, 64]) + assert metric.real_features_sum.shape == torch.Size([feature_dim]) + assert 
metric.real_features_cov_sum.shape == torch.Size([feature_dim, feature_dim]) assert metric.fake_features_num_samples == 2 - assert metric.fake_features_sum.shape == torch.Size([64]) - assert metric.fake_features_cov_sum.shape == torch.Size([64, 64]) + assert metric.fake_features_sum.shape == torch.Size([feature_dim]) + assert metric.fake_features_cov_sum.shape == torch.Size([feature_dim, feature_dim]) metric.reset() @@ -187,15 +215,17 @@ def test_reset_real_features_arg(reset_real_features): assert metric.real_features_num_samples == 0 else: assert metric.real_features_num_samples == 2 - assert metric.real_features_sum.shape == torch.Size([64]) - assert metric.real_features_cov_sum.shape == torch.Size([64, 64]) + assert metric.real_features_sum.shape == torch.Size([feature_dim]) + assert metric.real_features_cov_sum.shape == torch.Size([feature_dim, feature_dim]) @pytest.mark.parametrize("normalize", [True, False]) -def test_normalize_arg(normalize): +@pytest.mark.parametrize("feature_config", [("inception-64", (3, 299, 299)), ("dino-384", (3, 224, 224))]) +def test_normalize_arg(normalize, feature_config): """Test that normalize argument works as expected.""" - img = torch.rand(2, 3, 299, 299) - metric = FrechetInceptionDistance(normalize=normalize) + feature, img_size = feature_config + img = torch.rand(2, *img_size) + metric = FrechetInceptionDistance(feature=feature, normalize=normalize) context = ( partial( @@ -222,23 +252,25 @@ def test_not_enough_samples(): def test_dtype_transfer_to_submodule(): - """Test that change in dtype also changes the default inception net.""" + """Test that change in dtype also changes the default feature extractor net.""" imgs = torch.randn(1, 3, 256, 256) imgs = ((imgs.clamp(-1, 1) / 2 + 0.5) * 255).to(torch.uint8) - metric = FrechetInceptionDistance(feature=64) + metric = FrechetInceptionDistance(feature="inception-64") metric.set_dtype(torch.float64) - out = metric.inception(imgs) + out = metric.feature_extractor(imgs) assert 
out.dtype == torch.float64 -def test_antialias(): +@pytest.mark.parametrize("feature_config", [("inception-64", (3, 299, 299)), ("dino-384", (3, 224, 224))]) +def test_antialias(feature_config): """Test that on random input the antialiasing should produce similar results.""" - imgs = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8) + feature, img_size = feature_config + imgs = torch.randint(0, 255, (10, *img_size), dtype=torch.uint8) - metric_no_aa = FrechetInceptionDistance(feature=64, antialias=False) - metric_aa = FrechetInceptionDistance(feature=64, antialias=True) + metric_no_aa = FrechetInceptionDistance(feature=feature, antialias=False) + metric_aa = FrechetInceptionDistance(feature=feature, antialias=True) metric_no_aa.update(imgs, real=True) metric_no_aa.update(imgs, real=False)