From 57cb52914b05eb0c5502dc3820dfe536f0087ffd Mon Sep 17 00:00:00 2001 From: Bhimraj Yadav Date: Fri, 28 Nov 2025 18:52:59 +0000 Subject: [PATCH] refactor(tests): reduce input dimensions and feature parameters in FID tests to reduce execution time --- tests/unittests/audio/test_sdr.py | 8 ++++---- tests/unittests/image/test_fid.py | 2 +- tests/unittests/image/test_mifid.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index a36bbe87c7d..7f3b35c2977 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -31,13 +31,13 @@ inputs_1spk = _Input( - preds=torch.rand(2, 1, 1, 500), - target=torch.rand(2, 1, 1, 500), + preds=torch.rand(2, 1, 1, 100), + target=torch.rand(2, 1, 1, 100), ) inputs_2spk = _Input( - preds=torch.rand(2, 1, 2, 500), - target=torch.rand(2, 1, 2, 500), + preds=torch.rand(2, 1, 2, 100), + target=torch.rand(2, 1, 2, 100), ) diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index baec39f1197..1dc256573ea 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -95,7 +95,7 @@ def __call__(self, img) -> torch.Tensor: @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") -@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()]) +@pytest.mark.parametrize("feature", [64, 768, _DummyFeatureExtractor()]) def test_fid_same_input(feature): """If real and fake are update on the same data the fid score should be 0.""" metric = FrechetInceptionDistance(feature=feature) diff --git a/tests/unittests/image/test_mifid.py b/tests/unittests/image/test_mifid.py index 95d752b7ded..80d9e0e1d5a 100644 --- a/tests/unittests/image/test_mifid.py +++ b/tests/unittests/image/test_mifid.py @@ -141,7 +141,7 @@ def test_mifid_raises_errors_and_warnings(): @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") -@pytest.mark.parametrize("feature", [64, 192, 768, 2048]) +@pytest.mark.parametrize("feature", [64, 768]) def test_fid_same_input(feature): """If real and fake are update on the same data the fid score should be 0.""" metric = MemorizationInformedFrechetInceptionDistance(feature=feature)