Skip to content

Commit 22fd988

Browse files
committed
feat: Add Error Vector Magnitude (EVM) metric and corresponding tests
1 parent 93e2ecc commit 22fd988

File tree

3 files changed

+537
-0
lines changed

3 files changed

+537
-0
lines changed

kaira/metrics/signal/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .ber import BER, BitErrorRate
77
from .bler import BLER, FER, SER, BlockErrorRate, FrameErrorRate, SymbolErrorRate
8+
from .evm import EVM, ErrorVectorMagnitude
89
from .snr import SNR, SignalToNoiseRatio
910

1011
__all__ = [
@@ -18,4 +19,6 @@
1819
"FER",
1920
"SymbolErrorRate",
2021
"SER",
22+
"ErrorVectorMagnitude",
23+
"EVM",
2124
]

kaira/metrics/signal/evm.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
"""Error Vector Magnitude (EVM) metric.
2+
3+
EVM is a key performance indicator used in digital communication systems to quantify the difference
4+
between the ideal transmitted signal and the received signal. It provides a comprehensive measure
5+
of signal quality by considering both magnitude and phase errors.
6+
"""
7+
8+
from typing import Any, Optional
9+
10+
import torch
11+
from torch import Tensor
12+
13+
from ..base import BaseMetric
14+
from ..registry import MetricRegistry
15+
16+
17+
@MetricRegistry.register_metric("evm")
18+
class ErrorVectorMagnitude(BaseMetric):
19+
"""Error Vector Magnitude (EVM) metric.
20+
21+
EVM measures the difference between the ideal constellation points and the received
22+
constellation points, expressed as a percentage. It captures both magnitude and phase
23+
errors in the received signal. Lower EVM values indicate better signal quality.
24+
25+
EVM is calculated as:
26+
EVM(%) = sqrt(E[|error_vector|^2] / E[|reference_vector|^2]) * 100
27+
28+
where error_vector = received_signal - reference_signal
29+
30+
Attributes:
31+
normalize (bool): Whether to normalize by reference signal power (default: True).
32+
mode (str): EVM calculation mode ('rms', 'peak', or 'percentile').
33+
percentile (float): Percentile value when mode is 'percentile' (default: 95.0).
34+
"""
35+
36+
is_differentiable = True
37+
higher_is_better = False
38+
39+
def __init__(self, normalize: bool = True, mode: str = "rms", percentile: float = 95.0, name: Optional[str] = None, *args: Any, **kwargs: Any):
40+
"""Initialize the EVM metric.
41+
42+
Args:
43+
normalize (bool): Whether to normalize by reference signal power (default: True).
44+
mode (str): EVM calculation mode ('rms', 'peak', or 'percentile').
45+
percentile (float): Percentile value when mode is 'percentile' (default: 95.0).
46+
name (Optional[str]): Optional name for the metric.
47+
*args: Variable length argument list passed to the base class.
48+
**kwargs: Arbitrary keyword arguments passed to the base class.
49+
"""
50+
super().__init__(name=name or "EVM")
51+
self.normalize = normalize
52+
self.mode = mode.lower()
53+
self.percentile = percentile
54+
55+
if self.mode not in ["rms", "peak", "percentile"]:
56+
raise ValueError(f"Mode must be 'rms', 'peak', or 'percentile', got '{mode}'")
57+
58+
if not 0 < percentile <= 100:
59+
raise ValueError(f"Percentile must be between 0 and 100, got {percentile}")
60+
61+
def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
62+
"""Compute the Error Vector Magnitude for the current batch.
63+
64+
Args:
65+
x (Tensor): The transmitted/reference signal tensor.
66+
y (Tensor): The received signal tensor.
67+
*args: Variable length argument list (unused).
68+
**kwargs: Arbitrary keyword arguments (unused).
69+
70+
Returns:
71+
Tensor: Error Vector Magnitude as a percentage.
72+
"""
73+
if x.shape != y.shape:
74+
raise ValueError(f"Input shapes must match: {x.shape} vs {y.shape}")
75+
76+
# Handle empty tensors
77+
if x.numel() == 0:
78+
return torch.tensor(0.0, dtype=torch.float32, device=x.device)
79+
80+
# Calculate error vector
81+
error_vector = y - x
82+
83+
# Calculate error power (squared magnitude)
84+
error_power = torch.abs(error_vector) ** 2
85+
86+
if self.normalize:
87+
# Calculate reference power
88+
reference_power = torch.abs(x) ** 2
89+
90+
# Avoid division by zero
91+
reference_power = torch.clamp(reference_power, min=1e-12)
92+
93+
# Normalize error power by reference power
94+
normalized_error = error_power / reference_power
95+
else:
96+
normalized_error = error_power
97+
98+
# Calculate EVM based on mode
99+
if self.mode == "rms":
100+
# RMS EVM
101+
evm_squared = torch.mean(normalized_error)
102+
evm = torch.sqrt(evm_squared)
103+
elif self.mode == "peak":
104+
# Peak EVM
105+
evm_squared = torch.max(normalized_error)
106+
evm = torch.sqrt(evm_squared)
107+
elif self.mode == "percentile":
108+
# Percentile EVM
109+
evm_squared = torch.quantile(normalized_error.flatten(), self.percentile / 100.0)
110+
evm = torch.sqrt(evm_squared)
111+
112+
# Convert to percentage
113+
evm_percent = evm * 100.0
114+
115+
return evm_percent
116+
117+
def calculate_per_symbol_evm(self, x: Tensor, y: Tensor) -> Tensor:
118+
"""Calculate EVM for each symbol separately.
119+
120+
Args:
121+
x (Tensor): The transmitted/reference signal tensor.
122+
y (Tensor): The received signal tensor.
123+
124+
Returns:
125+
Tensor: Per-symbol EVM values as percentages.
126+
"""
127+
if x.shape != y.shape:
128+
raise ValueError(f"Input shapes must match: {x.shape} vs {y.shape}")
129+
130+
# Handle empty tensors
131+
if x.numel() == 0:
132+
return torch.tensor([], dtype=torch.float32, device=x.device)
133+
134+
# Calculate error vector
135+
error_vector = y - x
136+
137+
# Calculate per-symbol error magnitude
138+
error_magnitude = torch.abs(error_vector)
139+
140+
if self.normalize:
141+
# Calculate per-symbol reference magnitude
142+
reference_magnitude = torch.abs(x)
143+
reference_magnitude = torch.clamp(reference_magnitude, min=1e-12)
144+
145+
# Normalize by reference magnitude
146+
per_symbol_evm = error_magnitude / reference_magnitude
147+
else:
148+
per_symbol_evm = error_magnitude
149+
150+
# Convert to percentage
151+
per_symbol_evm_percent = per_symbol_evm * 100.0
152+
153+
return per_symbol_evm_percent
154+
155+
def calculate_statistics(self, x: Tensor, y: Tensor) -> dict:
156+
"""Calculate comprehensive EVM statistics.
157+
158+
Args:
159+
x (Tensor): The transmitted/reference signal tensor.
160+
y (Tensor): The received signal tensor.
161+
162+
Returns:
163+
dict: Dictionary containing various EVM statistics.
164+
"""
165+
# Calculate per-symbol EVM
166+
per_symbol_evm = self.calculate_per_symbol_evm(x, y)
167+
168+
# Calculate various statistics
169+
stats_dict = {
170+
"evm_rms": self.forward(x, y),
171+
"evm_mean": torch.mean(per_symbol_evm),
172+
"evm_std": torch.std(per_symbol_evm),
173+
"evm_min": torch.min(per_symbol_evm),
174+
"evm_max": torch.max(per_symbol_evm),
175+
"evm_median": torch.median(per_symbol_evm),
176+
"evm_95th": torch.quantile(per_symbol_evm.flatten(), 0.95),
177+
"evm_99th": torch.quantile(per_symbol_evm.flatten(), 0.99),
178+
"evm_per_symbol": per_symbol_evm,
179+
}
180+
181+
return stats_dict
182+
183+
184+
# Alias for backward compatibility
185+
EVM = ErrorVectorMagnitude

0 commit comments

Comments
 (0)