Skip to content

Commit b48efe8

Browse files
committed
Add comprehensive tests for SSIM and EVM metrics, including stateful methods, error handling, and edge cases
1 parent b661b3d commit b48efe8

File tree

2 files changed

+828
-0
lines changed

2 files changed

+828
-0
lines changed

tests/metrics/test_image_ssim.py

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
# tests/metrics/test_image_ssim.py
2+
"""Tests for SSIM (Structural Similarity Index Measure) metrics."""
3+
import pytest
4+
import torch
5+
6+
from kaira.metrics.image.ssim import SSIM, MultiScaleSSIM, StructuralSimilarityIndexMeasure
7+
8+
9+
class TestStructuralSimilarityIndexMeasure:
    """Test cases for Structural Similarity Index Measure (SSIM) metric.

    Most inputs are drawn with ``torch.rand``. The global torch RNG is
    re-seeded before every test so that a failing run can be reproduced.
    """

    def setup_method(self):
        """Seed the global torch RNG before each test for reproducible inputs."""
        torch.manual_seed(0)

    def test_ssim_basic_computation(self):
        """Test basic SSIM computation with simple images."""
        metric = StructuralSimilarityIndexMeasure()

        # Compare an image against an exact copy of itself
        img1 = torch.rand(1, 3, 32, 32)
        img2 = img1.clone()

        ssim = metric.forward(img1, img2)
        assert torch.isclose(ssim, torch.tensor([1.0]), atol=1e-4), f"SSIM should be ~1.0 for identical images, got {ssim}"

    def test_ssim_perfect_similarity(self):
        """Test SSIM with identical images."""
        metric = StructuralSimilarityIndexMeasure()

        # The very same tensor is passed as both arguments
        img = torch.rand(2, 3, 64, 64)
        ssim = metric.forward(img, img)

        assert torch.allclose(ssim, torch.ones_like(ssim), atol=1e-4), "SSIM should be 1.0 for identical images"

    def test_ssim_different_images(self):
        """Test SSIM with different images."""
        metric = StructuralSimilarityIndexMeasure()

        # Two constant images: all zeros versus all ones
        img1 = torch.zeros(1, 3, 32, 32)
        img2 = torch.ones(1, 3, 32, 32)

        ssim = metric.forward(img1, img2)
        assert ssim < 1.0, "SSIM should be less than 1.0 for different images"
        assert ssim >= 0.0, "SSIM should be non-negative"

    def test_ssim_data_range(self):
        """Test SSIM with different data ranges."""
        # Test with data_range=1.0 (default)
        metric1 = StructuralSimilarityIndexMeasure(data_range=1.0)

        # Test with data_range=255.0
        metric255 = StructuralSimilarityIndexMeasure(data_range=255.0)

        img_0_1 = torch.rand(1, 3, 32, 32)  # Range [0, 1]
        img_0_255 = img_0_1 * 255  # Range [0, 255]

        # Each metric compares an image with itself at its own scale
        ssim1 = metric1.forward(img_0_1, img_0_1)
        ssim255 = metric255.forward(img_0_255, img_0_255)

        assert torch.allclose(ssim1, ssim255, atol=1e-4), "SSIM should be similar regardless of data range for identical images"

    def test_ssim_kernel_size(self):
        """Test SSIM with different kernel sizes."""
        img1 = torch.rand(1, 3, 64, 64)
        img2 = torch.rand(1, 3, 64, 64)

        for kernel_size in [7, 11, 15]:
            metric = StructuralSimilarityIndexMeasure(kernel_size=kernel_size)
            ssim = metric.forward(img1, img2)
            assert torch.isfinite(ssim), f"SSIM should be finite for kernel_size={kernel_size}"
            # SSIM is bounded by [-1, 1], not [0, 1]: for two independent
            # random images the local covariance can be negative, so the
            # score may fall slightly below zero.
            assert -1 <= ssim <= 1, f"SSIM should be in [-1,1] for kernel_size={kernel_size}"

    def test_ssim_sigma(self):
        """Test SSIM with different sigma values."""
        img1 = torch.rand(1, 3, 32, 32)
        img2 = torch.rand(1, 3, 32, 32)

        for sigma in [0.5, 1.0, 1.5, 2.0]:
            metric = StructuralSimilarityIndexMeasure(sigma=sigma)
            ssim = metric.forward(img1, img2)
            assert torch.isfinite(ssim), f"SSIM should be finite for sigma={sigma}"

    def test_ssim_reduction_methods(self):
        """Test SSIM with different reduction methods."""
        img1 = torch.rand(3, 3, 32, 32)
        img2 = torch.rand(3, 3, 32, 32)

        # Test no reduction
        metric_none = StructuralSimilarityIndexMeasure(reduction=None)
        ssim_none = metric_none.forward(img1, img2)
        assert ssim_none.shape[0] == 3, "No reduction should return per-sample SSIM"

        # Test mean reduction
        metric_mean = StructuralSimilarityIndexMeasure(reduction="mean")
        ssim_mean = metric_mean.forward(img1, img2)
        assert ssim_mean.numel() == 1, "Mean reduction should return scalar"

        # Test sum reduction
        metric_sum = StructuralSimilarityIndexMeasure(reduction="sum")
        ssim_sum = metric_sum.forward(img1, img2)
        assert ssim_sum.numel() == 1, "Sum reduction should return scalar"

        # The reduced values must agree with reducing the per-sample values by hand
        assert torch.isclose(ssim_mean, ssim_none.mean()), "Mean reduction should equal manual mean"
        assert torch.isclose(ssim_sum, ssim_none.sum()), "Sum reduction should equal manual sum"

    def test_ssim_compute_with_stats(self):
        """Test SSIM compute_with_stats method."""
        metric = StructuralSimilarityIndexMeasure()

        img1 = torch.rand(5, 3, 32, 32)
        img2 = torch.rand(5, 3, 32, 32)

        mean_ssim, std_ssim = metric.compute_with_stats(img1, img2)

        assert torch.isfinite(mean_ssim), "Mean SSIM should be finite"
        assert torch.isfinite(std_ssim), "Std SSIM should be finite"
        assert std_ssim >= 0, "Standard deviation should be non-negative"

    def test_ssim_single_sample_stats(self):
        """Test SSIM stats computation with single sample."""
        metric = StructuralSimilarityIndexMeasure()

        # Batch of one: the test expects a standard deviation of exactly 0
        img1 = torch.rand(1, 3, 32, 32)
        img2 = torch.rand(1, 3, 32, 32)

        mean_ssim, std_ssim = metric.compute_with_stats(img1, img2)

        assert torch.isfinite(mean_ssim), "Mean SSIM should be finite for single sample"
        assert torch.isclose(std_ssim, torch.tensor(0.0)), "Std should be 0 for single sample"

    def test_ssim_batch_processing(self):
        """Test SSIM with different batch sizes."""
        metric = StructuralSimilarityIndexMeasure()

        for batch_size in [1, 2, 4, 8]:
            img1 = torch.rand(batch_size, 3, 32, 32)
            img2 = torch.rand(batch_size, 3, 32, 32)

            ssim = metric.forward(img1, img2)
            assert ssim.shape[0] == batch_size, f"SSIM should have batch_size={batch_size} outputs"

    def test_ssim_grayscale_images(self):
        """Test SSIM with grayscale images."""
        metric = StructuralSimilarityIndexMeasure()

        img1 = torch.rand(2, 1, 32, 32)  # Grayscale
        img2 = torch.rand(2, 1, 32, 32)

        ssim = metric.forward(img1, img2)
        assert ssim.shape[0] == 2, "SSIM should work with grayscale images"
        assert torch.isfinite(ssim).all(), "SSIM should be finite for grayscale images"

    def test_ssim_different_image_sizes(self):
        """Test SSIM with different image sizes."""
        metric = StructuralSimilarityIndexMeasure()

        for size in [16, 32, 64, 128]:
            img1 = torch.rand(1, 3, size, size)
            img2 = torch.rand(1, 3, size, size)

            ssim = metric.forward(img1, img2)
            assert torch.isfinite(ssim), f"SSIM should be finite for size {size}x{size}"

    def test_ssim_shape_mismatch(self):
        """Test SSIM with mismatched image shapes."""
        metric = StructuralSimilarityIndexMeasure()

        # Spatial sizes differ (32x32 versus 64x64)
        img1 = torch.rand(1, 3, 32, 32)
        img2 = torch.rand(1, 3, 64, 64)

        with pytest.raises((RuntimeError, ValueError)):
            metric.forward(img1, img2)
171+
172+
173+
class TestMultiScaleSSIM:
    """Test cases for Multi-Scale SSIM (MS-SSIM) metric.

    Most inputs are drawn with ``torch.rand``. The global torch RNG is
    re-seeded before every test so that a failing run can be reproduced.
    """

    def setup_method(self):
        """Seed the global torch RNG before each test for reproducible inputs."""
        torch.manual_seed(0)

    def test_ms_ssim_basic_computation(self):
        """Test basic MS-SSIM computation."""
        metric = MultiScaleSSIM()

        img1 = torch.rand(1, 3, 200, 200)  # MS-SSIM requires larger images (>160)
        img2 = img1.clone()

        ms_ssim = metric.forward(img1, img2)
        assert torch.isclose(ms_ssim, torch.tensor([1.0]), atol=1e-3), f"MS-SSIM should be ~1.0 for identical images, got {ms_ssim}"

    def test_ms_ssim_perfect_similarity(self):
        """Test MS-SSIM with identical images."""
        metric = MultiScaleSSIM()

        # The very same tensor is passed as both arguments
        img = torch.rand(2, 3, 200, 200)
        ms_ssim = metric.forward(img, img)

        assert torch.allclose(ms_ssim, torch.ones_like(ms_ssim), atol=1e-3), "MS-SSIM should be ~1.0 for identical images"

    def test_ms_ssim_different_images(self):
        """Test MS-SSIM with different images."""
        metric = MultiScaleSSIM()

        # Two constant images: all zeros versus all ones
        img1 = torch.zeros(1, 3, 200, 200)
        img2 = torch.ones(1, 3, 200, 200)

        ms_ssim = metric.forward(img1, img2)
        assert ms_ssim < 1.0, "MS-SSIM should be less than 1.0 for different images"
        assert ms_ssim >= 0.0, "MS-SSIM should be non-negative"

    def test_ms_ssim_data_range(self):
        """Test MS-SSIM with different data ranges."""
        metric1 = MultiScaleSSIM(data_range=1.0)
        metric255 = MultiScaleSSIM(data_range=255.0)

        img_0_1 = torch.rand(1, 3, 200, 200)
        img_0_255 = img_0_1 * 255

        # Each metric compares an image with itself at its own scale
        ms_ssim1 = metric1.forward(img_0_1, img_0_1)
        ms_ssim255 = metric255.forward(img_0_255, img_0_255)

        assert torch.allclose(ms_ssim1, ms_ssim255, atol=1e-3), "MS-SSIM should be similar regardless of data range"

    def test_ms_ssim_custom_weights(self):
        """Test MS-SSIM with custom weights."""
        # Five equal weights
        weights = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])
        metric = MultiScaleSSIM(weights=weights)

        img1 = torch.rand(1, 3, 200, 200)
        img2 = torch.rand(1, 3, 200, 200)

        ms_ssim = metric.forward(img1, img2)
        assert torch.isfinite(ms_ssim), "MS-SSIM should be finite with custom weights"

    def test_ms_ssim_reduction_methods(self):
        """Test MS-SSIM with different reduction methods."""
        img1 = torch.rand(3, 3, 200, 200)
        img2 = torch.rand(3, 3, 200, 200)

        # Test no reduction
        metric_none = MultiScaleSSIM(reduction=None)
        ms_ssim_none = metric_none.forward(img1, img2)
        assert ms_ssim_none.shape[0] == 3, "No reduction should return per-sample MS-SSIM"

        # Test mean reduction
        metric_mean = MultiScaleSSIM(reduction="mean")
        ms_ssim_mean = metric_mean.forward(img1, img2)
        assert ms_ssim_mean.numel() == 1, "Mean reduction should return scalar"

        # Test sum reduction
        metric_sum = MultiScaleSSIM(reduction="sum")
        ms_ssim_sum = metric_sum.forward(img1, img2)
        assert ms_ssim_sum.numel() == 1, "Sum reduction should return scalar"

    def test_ms_ssim_update_compute(self):
        """Test MS-SSIM update and compute methods."""
        metric = MultiScaleSSIM()

        img1 = torch.rand(2, 3, 200, 200)
        img2 = torch.rand(2, 3, 200, 200)

        # Test single update
        metric.reset()
        metric.update(img1, img2)
        mean, std = metric.compute()

        assert torch.isfinite(mean), "Mean should be finite"
        assert torch.isfinite(std), "Std should be finite"
        assert std >= 0, "Standard deviation should be non-negative"

    def test_ms_ssim_multiple_updates(self):
        """Test MS-SSIM with multiple updates."""
        metric = MultiScaleSSIM()

        metric.reset()

        # Multiple updates, each with a fresh pair of random batches
        for _ in range(3):
            img1 = torch.rand(2, 3, 200, 200)
            img2 = torch.rand(2, 3, 200, 200)
            metric.update(img1, img2)

        mean, std = metric.compute()
        assert torch.isfinite(mean), "Mean should be finite after multiple updates"
        assert torch.isfinite(std), "Std should be finite after multiple updates"

    def test_ms_ssim_compute_with_stats(self):
        """Test MS-SSIM compute_with_stats method."""
        metric = MultiScaleSSIM()

        img1 = torch.rand(4, 3, 200, 200)
        img2 = torch.rand(4, 3, 200, 200)

        mean_ms_ssim, std_ms_ssim = metric.compute_with_stats(img1, img2)

        assert torch.isfinite(mean_ms_ssim), "Mean MS-SSIM should be finite"
        assert torch.isfinite(std_ms_ssim), "Std MS-SSIM should be finite"
        assert std_ms_ssim >= 0, "Standard deviation should be non-negative"

    def test_ms_ssim_reset(self):
        """Test MS-SSIM reset functionality."""
        metric = MultiScaleSSIM()

        img1 = torch.rand(2, 3, 200, 200)
        img2 = torch.rand(2, 3, 200, 200)

        # Update and compute once before resetting. Only the post-reset
        # values are checked, so this result is deliberately not bound.
        metric.update(img1, img2)
        metric.compute()

        # Reset and check
        metric.reset()
        mean2, std2 = metric.compute()

        assert torch.isclose(mean2, torch.tensor(0.0)), "Mean should be 0 after reset"
        assert torch.isclose(std2, torch.tensor(0.0)), "Std should be 0 after reset"

    def test_ms_ssim_data_range_property(self):
        """Test MS-SSIM data_range property."""
        data_range = 255.0
        metric = MultiScaleSSIM(data_range=data_range)

        assert metric.data_range == data_range, f"data_range property should return {data_range}"

    def test_ms_ssim_kernel_size(self):
        """Test MS-SSIM with different kernel sizes."""
        # Use larger images for larger kernel sizes to satisfy torchmetrics constraints
        # For MS-SSIM with 5 betas and kernel_size=15, image must be > 224 pixels
        img1 = torch.rand(1, 3, 256, 256)  # Increased from 200x200 to 256x256
        img2 = torch.rand(1, 3, 256, 256)

        for kernel_size in [7, 11, 15]:
            metric = MultiScaleSSIM(kernel_size=kernel_size)
            ms_ssim = metric.forward(img1, img2)
            assert torch.isfinite(ms_ssim), f"MS-SSIM should be finite for kernel_size={kernel_size}"

    def test_ms_ssim_empty_update(self):
        """Test MS-SSIM update with empty tensors.

        Either outcome is accepted: the metric raises one of the listed
        errors, or it completes and reports a mean of 0.
        """
        metric = MultiScaleSSIM()

        # Batches with a zero-sized first dimension
        img1 = torch.rand(0, 3, 200, 200)
        img2 = torch.rand(0, 3, 200, 200)

        metric.reset()
        # Only the metric calls sit inside the try, so an error raised while
        # evaluating the assertion below is never silently swallowed.
        try:
            metric.update(img1, img2)
            mean, _ = metric.compute()
        except (RuntimeError, IndexError, ValueError):
            # It's acceptable if this raises an error for empty tensors
            # The underlying torchmetrics implementation doesn't handle empty tensors well
            return

        assert torch.isclose(mean, torch.tensor(0.0)), "Mean should be 0 for empty update"
350+
351+
352+
def test_ssim_alias():
    """Check that ``SSIM`` and ``StructuralSimilarityIndexMeasure`` are the same class object."""
    assert SSIM is StructuralSimilarityIndexMeasure
355+
356+
357+
def test_ssim_integration():
    """Check that SSIM and MS-SSIM both score an image against its copy at ~1.0."""
    reference = torch.rand(2, 3, 200, 200)
    duplicate = reference.clone()

    # Build both metrics with their default settings
    single_scale = StructuralSimilarityIndexMeasure()
    multi_scale = MultiScaleSSIM()

    single_scale_score = single_scale.forward(reference, duplicate)
    multi_scale_score = multi_scale.forward(reference, duplicate)

    # Identical inputs: every per-sample score is expected to be close to one
    single_scale_target = torch.ones_like(single_scale_score)
    assert torch.allclose(single_scale_score, single_scale_target, atol=1e-3), "SSIM should be ~1.0 for identical images"

    multi_scale_target = torch.ones_like(multi_scale_score)
    assert torch.allclose(multi_scale_score, multi_scale_target, atol=1e-3), "MS-SSIM should be ~1.0 for identical images"

0 commit comments

Comments
 (0)