Skip to content

Commit 71e48d1

Browse files
authored
Merge pull request #29 from mu373/plot-functions
Add tests to TailEstimatorSet
2 parents a6b7e99 + da980cb commit 71e48d1

File tree

3 files changed

+187
-4
lines changed

3 files changed

+187
-4
lines changed

src/tailestim/estimators/plot/plot_methods.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -703,8 +703,8 @@ def make_diagnostic_plots(ordered_data, results, output_file_path=None, number_o
703703
axes_d[2].plot(x1_k_arr[min_k1:max_k1], n1_k_amse[min_k1:max_k1],
704704
alpha = 0.5, lw = 1.5,
705705
color = "#d55e00", label = r"$n_1$ samples")
706-
axes_d[2].scatter([h1], [n1_k_amse[np.where(x1_k_arr == h1)]], color = "#d55e00",
707-
marker = 'o', edgecolor = "black", alpha = 0.5,
706+
axes_d[2].scatter([h1], [n1_k_amse[np.where(x1_k_arr == h1)][0]], color = "#d55e00",
707+
marker = 'o', edgecolor = "black", alpha = 0.5,
708708
label = r"Min for $n_1$ sample")
709709
# plot boundary of minimization
710710
axes_d[2].axvline(max_k_index1, color = "#d55e00",
@@ -713,8 +713,8 @@ def make_diagnostic_plots(ordered_data, results, output_file_path=None, number_o
713713
axes_d[2].plot(x2_k_arr[min_k2:max_k2], n2_k_amse[min_k2:max_k2],
714714
alpha = 0.5, lw = 1.5,
715715
color = "#0072b2", label = r"$n_2$ samples")
716-
axes_d[2].scatter([h2], [n2_k_amse[np.where(x2_k_arr == h2)]], color = "#0072b2",
717-
marker = 'o', edgecolor = "black", alpha = 0.5,
716+
axes_d[2].scatter([h2], [n2_k_amse[np.where(x2_k_arr == h2)][0]], color = "#0072b2",
717+
marker = 'o', edgecolor = "black", alpha = 0.5,
718718
label = r"Min for $n_2$ sample")
719719
axes_d[2].axvline(max_k_index2, color = "#0072b2",
720720
ls = '--', alpha = 0.5,

tests/test_estimator_set.py

Lines changed: 179 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,179 @@
1+
import matplotlib.pyplot as plt
import numpy as np
import pytest

from tailestim.datasets import TailData
from tailestim.estimators.estimator_set import TailEstimatorSet

# Module-level marks: pytest applies them to every test in this file.
# Each one ignores RuntimeWarnings whose message starts with the given text.
# Kept below the imports so that all imports stay at the top of the module.
pytestmark = [
    pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning"),
    pytest.mark.filterwarnings("ignore:divide by zero encountered in divide:RuntimeWarning"),
]
10+
11+
def test_tail_estimator_set_initialization():
    """A TailEstimatorSet built with no arguments exposes only empty state."""
    fresh = TailEstimatorSet()

    # Every one of these attributes is expected to be None right after
    # construction.
    for attribute in ("data", "ordered_data", "results", "fig", "axes"):
        assert getattr(fresh, attribute) is None
19+
20+
def test_tail_estimator_set_fit():
    """fit() populates data, ordered_data and results on the estimator set."""
    n_samples = 1000
    np.random.seed(42)  # fixed seed keeps the Pareto sample reproducible
    sample = np.random.pareto(2, n_samples)

    fitted = TailEstimatorSet()
    fitted.fit(sample)

    # All three pieces of state must be set once fit() has run.
    for attribute in ("data", "ordered_data", "results"):
        assert getattr(fitted, attribute) is not None

    # The stored data must have as many entries as the input sample.
    assert len(fitted.data) == n_samples

    # Ordering check: the first element must not be smaller than the last.
    ordered = fitted.ordered_data
    assert ordered[0] >= ordered[-1]
39+
40+
def test_tail_estimator_set_plot():
    """Check the figure and axes returned by plot() on a fitted estimator set.

    plot() must return a matplotlib Figure and a (3, 2) ndarray of axes, and
    the same two objects must be stored on the instance as ``fig`` and
    ``axes``.
    """
    # Generate Pareto distributed data with a fixed seed for reproducibility
    np.random.seed(42)
    data = np.random.pareto(2, 1000)

    estimator_set = TailEstimatorSet()
    estimator_set.fit(data)

    # Generate plots
    fig, axes = estimator_set.plot()

    # pyplot keeps every open figure in global state, so close the figure
    # even when one of the assertions below fails.
    try:
        # Check that plots were generated
        assert fig is not None
        assert axes is not None
        assert isinstance(fig, plt.Figure)
        assert isinstance(axes, np.ndarray)
        assert axes.shape == (3, 2)  # 3 rows, 2 columns of plots

        # Check that the figure and axes are stored in the object
        assert estimator_set.fig is fig
        assert estimator_set.axes is axes
    finally:
        plt.close(fig)
65+
66+
def test_tail_estimator_set_diagnostic_plot():
    """Check the figure and axes returned by plot_diagnostics().

    The estimator set is created with bootstrap and diagnostic plots enabled,
    fitted on a seeded Pareto sample, and must return a matplotlib Figure
    together with an ndarray of axes.
    """
    # Generate Pareto distributed data with a fixed seed for reproducibility
    np.random.seed(42)
    data = np.random.pareto(2, 1000)

    # NOTE: r_bootstrap is not passed here, so the library default is used.
    # A reduced value (r_bootstrap=100) was sketched for speed but never
    # enabled.
    estimator_set = TailEstimatorSet(
        bootstrap_flag=True,
        diagnostic_plots=True,
    )

    estimator_set.fit(data)

    # Generate diagnostic plots
    fig_d, axes_d = estimator_set.plot_diagnostics()

    # pyplot keeps every open figure in global state, so close the figure
    # even when one of the assertions below fails.
    try:
        assert fig_d is not None
        assert axes_d is not None
        assert isinstance(fig_d, plt.Figure)
        assert isinstance(axes_d, np.ndarray)
    finally:
        plt.close(fig_d)
92+
93+
def test_tail_estimator_set_with_built_in_dataset():
    """Run fit, plot() and plot_diagnostics() on the 'CAIDA_KONECT' dataset.

    The test is skipped when loading the dataset raises FileNotFoundError.
    """
    # Only the dataset load may trigger a skip. A FileNotFoundError raised
    # later by fitting or plotting must surface as a real failure instead of
    # being reported as "dataset not found".
    try:
        data = TailData(name='CAIDA_KONECT').data
    except FileNotFoundError:
        pytest.skip("Built-in dataset not found, skipping test")

    estimator_set = TailEstimatorSet(
        bootstrap_flag=True,
        diagnostic_plots=True,
        r_bootstrap=100,  # Reduce bootstrap iterations for faster testing
    )

    estimator_set.fit(data)

    # Generate plots; close the figure even if an assertion fails so it does
    # not stay registered in pyplot's global state.
    fig, axes = estimator_set.plot()
    try:
        assert fig is not None
        assert axes is not None
    finally:
        plt.close(fig)

    # Generate diagnostic plots, with the same cleanup guarantee.
    fig_d, axes_d = estimator_set.plot_diagnostics()
    try:
        assert fig_d is not None
        assert axes_d is not None
    finally:
        plt.close(fig_d)
121+
122+
def test_tail_estimator_set_errors():
    """Plotting methods raise ValueError with a specific message when misused.

    Covers three situations: plotting before any fit, diagnostic plots with
    bootstrap disabled, and diagnostic plots with diagnostics disabled.
    """
    unfitted = TailEstimatorSet()

    # Both plotting entry points must refuse to run before fit().
    for plotting_call in (unfitted.plot, unfitted.plot_diagnostics):
        with pytest.raises(ValueError, match="No data has been fitted"):
            plotting_call()

    # Seeded Pareto sample shared by the two fitted scenarios below.
    np.random.seed(42)
    sample = np.random.pareto(2, 1000)

    # Fitted, but bootstrap switched off.
    no_bootstrap = TailEstimatorSet(bootstrap_flag=False)
    no_bootstrap.fit(sample)
    with pytest.raises(ValueError, match="Diagnostic plots require bootstrap to be enabled"):
        no_bootstrap.plot_diagnostics()

    # Fitted with bootstrap on, but diagnostic plots switched off.
    no_diagnostics = TailEstimatorSet(bootstrap_flag=True, diagnostic_plots=False)
    no_diagnostics.fit(sample)
    with pytest.raises(ValueError, match="Diagnostic plots are not enabled"):
        no_diagnostics.plot_diagnostics()
152+
153+
def test_tail_estimator_set_parameters():
    """get_parameters() reports the constructor settings and the data length."""
    n_samples = 1000
    np.random.seed(42)
    sample = np.random.pareto(2, n_samples)

    # Non-default settings handed to the constructor as keyword arguments.
    overrides = {
        'number_of_bins': 50,
        'r_smooth': 3,
        'alpha': 0.7,
    }
    configured = TailEstimatorSet(**overrides)
    configured.fit(sample)

    reported = configured.get_parameters()

    # Each override must come back unchanged, and data_length must match the
    # size of the fitted sample.
    expected = {**overrides, 'data_length': n_samples}
    for key, value in expected.items():
        assert reported[key] == value

tests/test_tail_methods.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,9 @@
11
import numpy as np
22
import pytest
3+
# Module-level marks: pytest applies them to every test in this file.
# Each one ignores RuntimeWarnings whose message starts with the given text.
pytestmark = [
    pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning"),
    pytest.mark.filterwarnings("ignore:divide by zero encountered in divide:RuntimeWarning"),
]
37
from tailestim.estimators.tail_methods import (
48
add_uniform_noise,
59
get_distribution,

0 commit comments

Comments
 (0)