Skip to content

Commit b0d4386

Browse files
committed
test(families): tests for exponential family
1 parent d1a84a1 commit b0d4386

File tree

1 file changed

+255
-0
lines changed

1 file changed

+255
-0
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
"""
2+
Tests for Exponential Distribution Family
3+
4+
This module tests the functionality of the exponential distribution family,
5+
including parameterizations, characteristics, and sampling.
6+
"""
7+
8+
__author__ = "Fedor Myznikov"
9+
__copyright__ = "Copyright (c) 2025 PySATL project"
10+
__license__ = "SPDX-License-Identifier: MIT"
11+
12+
13+
import numpy as np
14+
import pytest
15+
from scipy.stats import expon
16+
17+
from pysatl_core.distributions.support import ContinuousSupport
18+
from pysatl_core.families.configuration import configure_families_register
19+
from pysatl_core.types import (
20+
CharacteristicName,
21+
ContinuousSupportShape1D,
22+
FamilyName,
23+
UnivariateContinuous,
24+
)
25+
26+
from .base import BaseDistributionTest
27+
28+
29+
class TestExponentialFamily(BaseDistributionTest):
30+
"""Test suite for Exponential distribution family."""
31+
32+
def setup_method(self):
33+
"""Setup before each test method."""
34+
registry = configure_families_register()
35+
self.exponential_family = registry.get(FamilyName.EXPONENTIAL)
36+
self.exponential_dist_example = self.exponential_family(lambda_=0.5)
37+
38+
def test_family_properties(self):
39+
"""Test basic properties of exponential family."""
40+
assert self.exponential_family.name == FamilyName.EXPONENTIAL
41+
42+
# Check parameterizations
43+
expected_parametrizations = {"rate", "scale"}
44+
assert set(self.exponential_family.parametrization_names) == expected_parametrizations
45+
assert self.exponential_family.base_parametrization_name == "rate"
46+
47+
def test_rate_parametrization_creation(self):
48+
"""Test creation of distribution with rate parametrization."""
49+
dist = self.exponential_family(lambda_=0.5)
50+
51+
assert dist.family_name == FamilyName.EXPONENTIAL
52+
assert dist.distribution_type == UnivariateContinuous
53+
assert dist.parameters == {"lambda_": 0.5}
54+
assert dist.parametrization_name == "rate"
55+
56+
def test_scale_parametrization_creation(self):
57+
"""Test creation of distribution with scale parametrization."""
58+
dist = self.exponential_family(beta=2.0, parametrization_name="scale")
59+
60+
assert dist.parameters == {"beta": 2.0}
61+
assert dist.parametrization_name == "scale"
62+
63+
def test_parametrization_constraints(self):
64+
"""Test parameter constraints validation."""
65+
# lambda_ must be positive
66+
with pytest.raises(ValueError, match="lambda_ > 0"):
67+
self.exponential_family(lambda_=-1.0)
68+
69+
# beta must be positive
70+
with pytest.raises(ValueError, match="beta > 0"):
71+
self.exponential_family(beta=0.0, parametrization_name="scale")
72+
73+
def test_moments(self):
74+
"""Test moment calculations."""
75+
# Mean
76+
mean_func = self.exponential_dist_example.query_method(CharacteristicName.MEAN)
77+
assert abs(mean_func(None) - 2.0) < self.CALCULATION_PRECISION
78+
79+
# Variance
80+
var_func = self.exponential_dist_example.query_method(CharacteristicName.VAR)
81+
assert abs(var_func(None) - 4.0) < self.CALCULATION_PRECISION
82+
83+
# Skewness
84+
skew_func = self.exponential_dist_example.query_method(CharacteristicName.SKEW)
85+
assert abs(skew_func(None) - 2.0) < self.CALCULATION_PRECISION
86+
87+
def test_kurtosis_calculation(self):
88+
"""Test kurtosis calculation with excess parameter."""
89+
kurt_func = self.exponential_dist_example.query_method(CharacteristicName.KURT)
90+
91+
raw_kurt = kurt_func(None)
92+
assert abs(raw_kurt - 9.0) < self.CALCULATION_PRECISION
93+
94+
excess_kurt = kurt_func(None, excess=True)
95+
assert abs(excess_kurt - 6.0) < self.CALCULATION_PRECISION
96+
97+
raw_kurt_explicit = kurt_func(None, excess=False)
98+
assert abs(raw_kurt_explicit - 9.0) < self.CALCULATION_PRECISION
99+
100+
@pytest.mark.parametrize(
101+
"parametrization_name, params, expected_lambda",
102+
[
103+
("rate", {"lambda_": 0.5}, 0.5),
104+
("scale", {"beta": 2.0}, 0.5), # lambda = 1/beta = 0.5
105+
],
106+
)
107+
def test_parametrization_conversions(self, parametrization_name, params, expected_lambda):
108+
"""Test conversions between different parameterizations."""
109+
base_params = self.exponential_family.to_base(
110+
self.exponential_family.get_parametrization(parametrization_name)(**params)
111+
)
112+
113+
assert abs(base_params.parameters["lambda_"] - expected_lambda) < self.CALCULATION_PRECISION
114+
115+
def test_analytical_computations_availability(self):
116+
"""Test that analytical computations are available for exponential distribution."""
117+
comp = self.exponential_family(lambda_=1.0).analytical_computations
118+
119+
expected_chars = {
120+
CharacteristicName.PDF,
121+
CharacteristicName.CDF,
122+
CharacteristicName.PPF,
123+
CharacteristicName.CF,
124+
CharacteristicName.MEAN,
125+
CharacteristicName.VAR,
126+
CharacteristicName.SKEW,
127+
CharacteristicName.KURT,
128+
}
129+
assert set(comp.keys()) == expected_chars
130+
131+
@pytest.mark.parametrize(
132+
"char_name, test_data, scipy_func, scipy_kwargs",
133+
[
134+
(CharacteristicName.PDF, [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], expon.pdf, {"scale": 2.0}),
135+
(CharacteristicName.CDF, [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], expon.cdf, {"scale": 2.0}),
136+
(
137+
CharacteristicName.PPF,
138+
[0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 0.999],
139+
expon.ppf,
140+
{"scale": 2.0},
141+
),
142+
],
143+
)
144+
def test_array_input_for_characteristics(self, char_name, test_data, scipy_func, scipy_kwargs):
145+
"""Test that characteristics support array inputs."""
146+
dist = self.exponential_dist_example
147+
char_func = dist.query_method(char_name)
148+
149+
input_array = np.array(test_data)
150+
result_array = char_func(input_array)
151+
152+
assert result_array.shape == input_array.shape
153+
154+
expected_array = scipy_func(input_array, **scipy_kwargs)
155+
156+
self.assert_arrays_almost_equal(result_array, expected_array)
157+
158+
def test_characteristic_function_array_input(self):
159+
"""Test characteristic function calculation with array input."""
160+
char_func = self.exponential_dist_example.query_method(CharacteristicName.CF)
161+
t_array = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
162+
163+
cf_array = char_func(t_array)
164+
assert cf_array.shape == t_array.shape
165+
166+
lambda_ = 0.5
167+
denominator = lambda_**2 + t_array**2
168+
expected_real = lambda_**2 / denominator
169+
expected_imag = lambda_ * t_array / denominator
170+
171+
expected_real = np.where(np.abs(t_array) < self.CALCULATION_PRECISION, 1.0, expected_real)
172+
expected_imag = np.where(np.abs(t_array) < self.CALCULATION_PRECISION, 0.0, expected_imag)
173+
174+
expected = expected_real + 1j * expected_imag
175+
176+
self.assert_arrays_almost_equal(cf_array.real, expected.real)
177+
self.assert_arrays_almost_equal(cf_array.imag, expected.imag)
178+
179+
def test_exponential_support(self):
180+
"""Test that exponential distribution has correct support [0, ∞)."""
181+
dist = self.exponential_dist_example
182+
183+
assert dist.support is not None
184+
assert isinstance(dist.support, ContinuousSupport)
185+
186+
assert dist.support.left == 0.0
187+
assert dist.support.right == float("inf")
188+
assert dist.support.left_closed
189+
assert not dist.support.right_closed
190+
191+
# Test containment
192+
assert dist.support.contains(0.0) is True
193+
assert dist.support.contains(1.0) is True
194+
assert dist.support.contains(-0.1) is False
195+
assert dist.support.contains(float("inf")) is False
196+
197+
# Test array
198+
test_points = np.array([-0.1, 0.0, 1.0, 10.0])
199+
expected = np.array([False, True, True, True])
200+
results = dist.support.contains(test_points)
201+
np.testing.assert_array_equal(results, expected)
202+
203+
assert dist.support.shape == ContinuousSupportShape1D.RAY_RIGHT
204+
205+
206+
class TestExponentialFamilyEdgeCases(BaseDistributionTest):
207+
"""Test edge cases and error conditions for exponential distribution."""
208+
209+
def setup_method(self):
210+
"""Setup before each test method."""
211+
registry = configure_families_register()
212+
self.exponential_family = registry.get(FamilyName.EXPONENTIAL)
213+
214+
def test_invalid_parameterization(self):
215+
"""Test error for invalid parameterization name."""
216+
with pytest.raises(KeyError):
217+
self.exponential_family.distribution(parametrization_name="invalid_name", lambda_=1.0)
218+
219+
def test_missing_parameters(self):
220+
"""Test error for missing required parameters."""
221+
with pytest.raises(TypeError):
222+
self.exponential_family.distribution() # Missing lambda_
223+
224+
def test_invalid_probability_ppf(self):
225+
"""Test PPF with invalid probability values."""
226+
dist = self.exponential_family(lambda_=1.0)
227+
ppf = dist.query_method(CharacteristicName.PPF)
228+
229+
# Test boundaries
230+
assert ppf(0.0) == 0.0
231+
assert ppf(1.0) == float("inf")
232+
233+
# Test invalid probabilities
234+
with pytest.raises(ValueError):
235+
ppf(-0.1)
236+
with pytest.raises(ValueError):
237+
ppf(1.1)
238+
239+
def test_characteristic_function_at_zero(self):
240+
"""Test characteristic function at zero returns 1."""
241+
dist = self.exponential_family(lambda_=1.0)
242+
char_func = dist.query_method(CharacteristicName.CF)
243+
244+
cf_value_zero = char_func(0.0)
245+
assert abs(cf_value_zero.real - 1.0) < self.CALCULATION_PRECISION
246+
assert abs(cf_value_zero.imag) < self.CALCULATION_PRECISION
247+
248+
def test_characteristic_function_large_t(self):
249+
"""Test characteristic function with large t."""
250+
dist = self.exponential_family(lambda_=1.0)
251+
char_func = dist.query_method(CharacteristicName.CF)
252+
253+
cf_value_large = char_func(1000.0)
254+
assert np.iscomplexobj(cf_value_large)
255+
assert abs(cf_value_large) <= 1.0

0 commit comments

Comments
 (0)