Skip to content

Commit 622949f

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #305 from IFCA-Advanced-Computing/feature-energy-distance
Add Energy distance data drift method
2 parents 6183197 + 19c87e7 commit 622949f

File tree

8 files changed

+97
-3
lines changed

8 files changed

+97
-3
lines changed

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,9 @@ The currently implemented detectors are listed in the following table.
335335
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1007/978-3-540-75488-6_27">Nishida and Yamauchi (2007)</a></td>
336336
</tr>
337337
<tr>
338-
<td rowspan="16" style="text-align: center; border: 1px solid grey; padding: 8px;">Data drift</td>
339-
<td rowspan="14" style="text-align: center; border: 1px solid grey; padding: 8px;">Batch</td>
340-
<td rowspan="9" style="text-align: center; border: 1px solid grey; padding: 8px;">Distance based</td>
338+
<td rowspan="17" style="text-align: center; border: 1px solid grey; padding: 8px;">Data drift</td>
339+
<td rowspan="15" style="text-align: center; border: 1px solid grey; padding: 8px;">Batch</td>
340+
<td rowspan="10" style="text-align: center; border: 1px solid grey; padding: 8px;">Distance based</td>
341341
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
342342
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
343343
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Anderson-Darling test</td>
@@ -355,6 +355,12 @@ The currently implemented detectors are listed in the following table.
355355
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Earth Mover's distance</td>
356356
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1023/A:1026543900054">Rubner et al. (2000)</a></td>
357357
</tr>
358+
<tr>
359+
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
360+
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
361+
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Energy distance</td>
362+
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1016/j.jspi.2013.03.018">Székely et al. (2013)</a></td>
363+
</tr>
358364
<tr>
359365
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
360366
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>

docs/source/api_reference/detectors/data_drift/batch.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ The {mod}`frouros.detectors.data_drift.batch` module contains batch data drift d
2626
2727
BhattacharyyaDistance
2828
EMD
29+
EnergyDistance
2930
HellingerDistance
3031
HINormalizedComplement
3132
JS

frouros/detectors/data_drift/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ChiSquareTest,
77
CVMTest,
88
EMD,
9+
EnergyDistance,
910
HellingerDistance,
1011
HINormalizedComplement,
1112
JS,
@@ -25,6 +26,7 @@
2526
"ChiSquareTest",
2627
"CVMTest",
2728
"EMD",
29+
"EnergyDistance",
2830
"HellingerDistance",
2931
"HINormalizedComplement",
3032
"IncrementalKSTest",

frouros/detectors/data_drift/batch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .distance_based import (
44
BhattacharyyaDistance,
55
EMD,
6+
EnergyDistance,
67
HellingerDistance,
78
HINormalizedComplement,
89
JS,
@@ -25,6 +26,7 @@
2526
"ChiSquareTest",
2627
"CVMTest",
2728
"EMD",
29+
"EnergyDistance",
2830
"HellingerDistance",
2931
"HINormalizedComplement",
3032
"JS",

frouros/detectors/data_drift/batch/distance_based/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .bhattacharyya_distance import BhattacharyyaDistance
44
from .emd import EMD
5+
from .energy_distance import EnergyDistance
56
from .hellinger_distance import HellingerDistance
67
from .hi_normalized_complement import HINormalizedComplement
78
from .js import JS
@@ -12,6 +13,7 @@
1213
__all__ = [
1314
"BhattacharyyaDistance",
1415
"EMD",
16+
"EnergyDistance",
1517
"HellingerDistance",
1618
"HINormalizedComplement",
1719
"JS",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Energy Distance module."""
2+
3+
from typing import Optional, Union
4+
5+
import numpy as np # type: ignore
6+
from scipy.stats import energy_distance # type: ignore
7+
8+
from frouros.callbacks.batch.base import BaseCallbackBatch
9+
from frouros.detectors.data_drift.base import UnivariateData
10+
from frouros.detectors.data_drift.batch.distance_based.base import (
11+
BaseDistanceBased,
12+
DistanceResult,
13+
)
14+
15+
16+
class EnergyDistance(BaseDistanceBased):
    """EnergyDistance [szekely2013energy]_ detector.

    Univariate, distance-based batch detector that wraps
    ``scipy.stats.energy_distance``.

    :param callbacks: callbacks, defaults to None
    :type callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]]
    :param kwargs: additional keyword arguments to pass to scipy.stats.energy_distance
    :type kwargs: Dict[str, Any]

    :References:

    .. [szekely2013energy] Székely, Gábor J., and Maria L. Rizzo.
        "Energy statistics: A class of statistics based on distances."
        Journal of statistical planning and inference 143.8 (2013): 1249-1272.

    :Example:

    >>> from frouros.detectors.data_drift import EnergyDistance
    >>> import numpy as np
    >>> np.random.seed(seed=31)
    >>> X = np.random.normal(loc=0, scale=1, size=100)
    >>> Y = np.random.normal(loc=1, scale=1, size=100)
    >>> detector = EnergyDistance()
    >>> _ = detector.fit(X=X)
    >>> detector.compare(X=Y)[0]
    DistanceResult(distance=0.8359206395514527)
    """  # noqa: E501

    def __init__(  # noqa: D107
        self,
        callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
        **kwargs,
    ) -> None:
        # The same kwargs dict is handed to the base class and also kept on
        # the instance, because _distance_measure reads it from self.kwargs.
        super().__init__(
            callbacks=callbacks,
            statistical_type=UnivariateData(),
            statistical_method=self._energy_distance,
            statistical_kwargs=kwargs,
        )
        self.kwargs = kwargs

    def _distance_measure(
        self,
        X_ref: np.ndarray,  # noqa: N803
        X: np.ndarray,  # noqa: N803
        **kwargs,
    ) -> DistanceResult:
        """Compute the energy distance between reference and new samples.

        :param X_ref: reference data
        :type X_ref: numpy.ndarray
        :param X: data to compare against the reference
        :type X: numpy.ndarray
        :param kwargs: accepted for interface compatibility but not used
            here; the keyword arguments given at construction time
            (``self.kwargs``) are forwarded instead
        :return: energy distance wrapped in a DistanceResult
        :rtype: DistanceResult
        """
        return DistanceResult(
            distance=self._energy_distance(X=X_ref, Y=X, **self.kwargs),
        )

    @staticmethod
    def _energy_distance(
        X: np.ndarray,  # noqa: N803
        Y: np.ndarray,
        **kwargs,
    ) -> float:
        """Flatten both inputs and delegate to ``scipy.stats.energy_distance``.

        :param X: first sample
        :type X: numpy.ndarray
        :param Y: second sample
        :type Y: numpy.ndarray
        :param kwargs: extra keyword arguments forwarded to
            ``scipy.stats.energy_distance``
        :return: energy distance between the flattened samples
        :rtype: float
        """
        return energy_distance(
            u_values=X.flatten(),
            v_values=Y.flatten(),
            **kwargs,
        )

frouros/tests/integration/test_callback.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
BhattacharyyaDistance,
3131
CVMTest,
3232
EMD,
33+
EnergyDistance,
3334
HellingerDistance,
3435
HINormalizedComplement,
3536
JS,
@@ -48,6 +49,7 @@
4849
[
4950
(BhattacharyyaDistance, 0.55516059, 0.0),
5051
(EMD, 3.85346006, 0.0),
52+
(EnergyDistance, 2.11059982, 0.0),
5153
(HellingerDistance, 0.74509099, 0.0),
5254
(HINormalizedComplement, 0.78, 0.0),
5355
(JS, 0.67010107, 0.0),

frouros/tests/integration/test_data_drift.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from frouros.detectors.data_drift.batch import (
99
BhattacharyyaDistance,
1010
EMD,
11+
EnergyDistance,
1112
HellingerDistance,
1213
HINormalizedComplement,
1314
PSI,
@@ -64,6 +65,7 @@ def test_batch_distance_based_categorical(
6465
"detector, expected_distance",
6566
[
6667
(EMD(), 3.85346006),
68+
(EnergyDistance(), 2.11059982),
6769
(JS(), 0.67010107),
6870
(KL(), np.inf),
6971
(HINormalizedComplement(), 0.78),

0 commit comments

Comments
 (0)