Skip to content

Commit a65ff1d

Browse files
committed
Refactor examples and tests for improved clarity and coverage
- Removed unnecessary print statements from plot_channel_aware_base_model.py to streamline output. - Updated plot_fec_decoders_tutorial.py and plot_fec_encoders_tutorial.py by removing redundant section headers for better readability. - Added comprehensive tests for Structural Similarity Index Measure (SSIM) and Multi-Scale SSIM (MS-SSIM) metrics, ensuring various scenarios are covered. - Introduced Error Vector Magnitude (EVM) metric tests, including basic computation, edge cases, and stateful methods. - Enhanced Bit Error Rate (BER) tests to cover stateful methods and complex data handling. - Improved test coverage for Signal-to-Noise Ratio (SNR) and Block Error Rate (BLER) metrics, addressing edge cases and ensuring robustness.
1 parent b48efe8 commit a65ff1d

File tree

2 files changed

+424
-21
lines changed

2 files changed

+424
-21
lines changed

kaira/metrics/image/ssim.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
6161
"""
6262
# Note: *args and **kwargs are not directly used by self.ssim call here
6363
# but are included for interface consistency.
64+
65+
# Handle empty tensors gracefully
66+
if x.numel() == 0 or y.numel() == 0:
67+
# Return empty tensor with appropriate shape
68+
batch_size = x.shape[0] if x.numel() >= 0 else 0
69+
return torch.tensor([], device=x.device, dtype=x.dtype).view(batch_size)
70+
6471
values = self.ssim(x, y)
6572

6673
# Apply reduction if specified
@@ -155,6 +162,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, *args: Any, **kwargs: Any) -
155162
# Note: *args and **kwargs are not directly used here
156163
# but are included for interface consistency.
157164

165+
# Handle empty tensors gracefully
166+
if x.numel() == 0 or y.numel() == 0:
167+
# Return empty tensor with appropriate shape
168+
batch_size = x.shape[0] if x.numel() >= 0 else 0
169+
return torch.tensor([], device=x.device, dtype=x.dtype).view(batch_size)
170+
158171
# Use torchmetrics MS-SSIM implementation
159172
values = self.ms_ssim(x, y)
160173

@@ -178,6 +191,10 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, *args: Any, **kwarg
178191
*args: Variable length argument list passed to forward.
179192
**kwargs: Arbitrary keyword arguments passed to forward.
180193
"""
194+
# Handle empty tensors gracefully
195+
if preds.numel() == 0 or targets.numel() == 0:
196+
return # Skip update for empty tensors
197+
181198
values = self.forward(preds, targets, *args, **kwargs) # Pass args/kwargs
182199
if values.numel() == 0:
183200
return # Avoid updating with empty values

0 commit comments

Comments
 (0)