Skip to content

Commit 90e378b

Browse files
refractor simulator.models.test
1 parent aa154e1 commit 90e378b

File tree

2 files changed

+55
-66
lines changed

2 files changed

+55
-66
lines changed

stingray/simulator/models.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ def fit_deriv(x, x_0, fwhm, value, power_coeff):
7272
denom = mod_x_pc + fwhm_pc
7373
denom_sq = np.power(denom, 2)
7474

75-
d_x = -1.0 * num / denom_sq * (power_coeff * mod_x_pc / np.abs(x - x_0)) * np.sign(x - x_0)
76-
d_x_0 = -d_x
75+
d_x_0 = 1.0 * num / denom_sq * (power_coeff * mod_x_pc / np.abs(x - x_0)) * np.sign(x - x_0)
7776
d_value = fwhm_pc / denom
7877

7978
pre_compute = 1.0 / 2.0 * power_coeff * fwhm_pc / (fwhm / 2)
@@ -87,7 +86,7 @@ def fit_deriv(x, x_0, fwhm, value, power_coeff):
8786
- num * (np.log(abs(x - x_0)) * mod_x_pc + np.log(fwhm / 2) * fwhm_pc)
8887
)
8988
)
90-
return [d_x, d_x_0, d_value, d_fwhm, d_power_coeff]
89+
return [d_x_0, d_value, d_fwhm, d_power_coeff]
9190

9291
def bounding_box(self, factor=25):
9392
"""Tuple defining the default ``bounding_box`` limits,
@@ -156,9 +155,9 @@ class SmoothBrokenPowerLaw(Fittable1DModel):
156155
"""
157156

158157
norm = Parameter(default=1.0, description="normalization frequency")
159-
break_freq = Parameter(default=1.0, description="Break frequency")
160158
gamma_low = Parameter(default=-1.0, description="Power law index for f --> zero")
161159
gamma_high = Parameter(default=1.0, description="Power law index for f --> infinity")
160+
break_freq = Parameter(default=1.0, description="Break frequency")
162161

163162
def _norm_validator(self, value):
164163
if np.any(value <= 0):
@@ -189,19 +188,15 @@ def evaluate(x, norm, gamma_low, gamma_high, break_freq):
189188
threshold = 30 # corresponding to exp(30) ~ 1e13
190189
i = logt > threshold
191190
if i.max():
192-
log_f = (
193-
np.log(norm) - gamma_low * np.log(x[i]) + (gamma_low - gamma_high) * np.log(xx[i])
194-
)
195-
f[i] = np.exp(log_f)
191+
f[i] = norm * np.power(x[i], -gamma_low) * np.power(xx[i], gamma_low - gamma_high)
196192

197193
i = logt < -threshold
198194
if i.max():
199-
log_f = np.log(norm) - gamma_low * np.log(x[i])
200-
f[i] = np.exp(log_f)
195+
f[i] = norm * np.power(x[i], -gamma_low)
201196

202197
i = np.abs(logt) <= threshold
203198
if i.max():
204-
# In this case the `t` value is "comparable" to 1, hence we
199+
# In this case the `t` value is "comparable" to 1, hence
205200
# we will evaluate the whole formula.
206201
f[i] = (
207202
norm
@@ -230,10 +225,7 @@ def fit_deriv(x, norm, gamma_low, gamma_high, break_freq):
230225
threshold = 30 # (see comments in SmoothBrokenPowerLaw.evaluate)
231226
i = logt > threshold
232227
if i.max():
233-
log_f = (
234-
np.log(norm) - gamma_low * np.log(x[i]) + (gamma_low - gamma_high) * np.log(xx[i])
235-
)
236-
f[i] = np.exp(log_f)
228+
f[i] = norm * np.power(x[i], -gamma_low) * np.power(xx[i], gamma_low - gamma_high)
237229

238230
d_norm[i] = f[i] / norm
239231
d_gamma_low[i] = f[i] * (-np.log(x[i]) + np.log(xx[i]))
@@ -242,8 +234,7 @@ def fit_deriv(x, norm, gamma_low, gamma_high, break_freq):
242234

243235
i = logt < -threshold
244236
if i.max():
245-
log_f = np.log(norm) - gamma_low * np.log(x[i])
246-
f[i] = np.exp(log_f)
237+
f[i] = norm * np.power(x[i], -gamma_low)
247238

248239
d_norm[i] = f[i] / norm
249240
d_gamma_low[i] = -f[i] * np.log(x[i])

stingray/simulator/tests/test_models.py

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,72 +9,70 @@
99

1010

1111
class TestModel(object):
12+
@classmethod
13+
def setup_class(self):
14+
self.lorentz1D = models.GeneralizedLorentz1D(x_0=3, fwhm=32, value=2.5, power_coeff=2)
15+
self.smoothPowerlaw = models.SmoothBrokenPowerLaw(
16+
norm=1, gamma_low=-2, gamma_high=2, break_freq=10
17+
)
18+
19+
def test_model_param(self):
20+
lorentz1D = self.lorentz1D
21+
smoothPowerlaw = self.smoothPowerlaw
22+
23+
assert np.allclose(smoothPowerlaw.parameters, np.array([1, -2, 2, 10]))
24+
assert np.allclose(lorentz1D.parameters, np.array([3.0, 32.0, 2.5, 2.0]))
25+
26+
assert np.array_equal(
27+
lorentz1D.param_names, np.array(["x_0", "fwhm", "value", "power_coeff"])
28+
)
29+
assert np.array_equal(
30+
smoothPowerlaw.param_names, np.array(["norm", "gamma_low", "gamma_high", "break_freq"])
31+
)
1232

1333
def test_power_coeff(self):
1434
with pytest.raises(
1535
InputParameterError, match="The power coefficient should be greater than zero."
1636
):
1737
models.GeneralizedLorentz1D(x_0=2, fwhm=100, value=3, power_coeff=-1)
1838

19-
def test_lorentz_model(self):
20-
model = models.GeneralizedLorentz1D(x_0=3, fwhm=32, value=2.5, power_coeff=2)
39+
@pytest.mark.parametrize(
40+
"model, yy_func, params",
41+
[
42+
(models.SmoothBrokenPowerLaw, models.smoothbknpo, [1, -2, 2, 10]),
43+
(models.GeneralizedLorentz1D, models.generalized_lorentzian, [3, 32, 2.5, 2]),
44+
],
45+
)
46+
def test_model_evaluate(self, model, yy_func, params):
47+
model = model(*params)
2148
xx = np.linspace(2, 4, 6)
2249
yy = model(xx)
23-
yy_ref = [
24-
2.4902723735,
25-
2.4964893119,
26-
2.4996094360,
27-
2.4996094360,
28-
2.4964893119,
29-
2.4902723735,
30-
]
50+
yy_ref = yy_func(xx, params)
51+
3152
assert_allclose(yy, yy_ref, rtol=0, atol=1e-8)
53+
assert xx.shape == yy.shape == yy_ref.shape
3254

33-
def test_SmoothBrokenPowerLaw_fit_deriv(self):
34-
x_lim = [0.01, 100]
55+
@pytest.mark.parametrize(
56+
"model, x_lim",
57+
[
58+
(models.SmoothBrokenPowerLaw(1, -2, 2, 10), [0.01, 70]),
59+
(models.GeneralizedLorentz1D(3, 32, 2.5, 2), [-10, 10]),
60+
],
61+
)
62+
def test_model_fitting(self, model, x_lim):
3563
x = np.logspace(x_lim[0], x_lim[1], 100)
3664

37-
model_with_deriv = models.SmoothBrokenPowerLaw(1, 10, -2, 2)
38-
model_no_deriv = models.SmoothBrokenPowerLaw(1, 10, -2, 2)
65+
model_with_deriv = model
66+
model_no_deriv = model
3967

4068
# add 10% noise to the amplitude
41-
# fmt: off
42-
rsn_rand_1234567890 = np.array(
43-
[
44-
0.61879477, 0.59162363, 0.88868359, 0.89165480, 0.45756748,
45-
0.77818808, 0.26706377, 0.99610621, 0.54009489, 0.53752161,
46-
0.40099938, 0.70540579, 0.40518559, 0.94999075, 0.03075388,
47-
0.13602495, 0.08297726, 0.42352224, 0.23449723, 0.74743526,
48-
0.65177865, 0.68998682, 0.16413419, 0.87642114, 0.44733314,
49-
0.57871104, 0.52377835, 0.62689056, 0.34869427, 0.26209748,
50-
0.07498055, 0.17940570, 0.82999425, 0.98759822, 0.11326099,
51-
0.63846415, 0.73056694, 0.88321124, 0.52721004, 0.66487673,
52-
0.74209309, 0.94083846, 0.70123128, 0.29534353, 0.76134369,
53-
0.77593881, 0.36985514, 0.89519067, 0.33082813, 0.86108824,
54-
0.76897859, 0.61343376, 0.43870907, 0.91913538, 0.76958966,
55-
0.51063556, 0.04443249, 0.57463611, 0.31382006, 0.41221713,
56-
0.21531811, 0.03237521, 0.04166386, 0.73109303, 0.74556052,
57-
0.64716325, 0.77575353, 0.64599254, 0.16885816, 0.48485480,
58-
0.53844248, 0.99690349, 0.23657074, 0.04119088, 0.46501519,
59-
0.35739006, 0.23002665, 0.53420791, 0.71639475, 0.81857486,
60-
0.73994342, 0.07948837, 0.75688276, 0.13240193, 0.48465576,
61-
0.20624753, 0.02298276, 0.54257873, 0.68123230, 0.35887468,
62-
0.36296147, 0.67368397, 0.29505730, 0.66558885, 0.93652252,
63-
0.36755130, 0.91787687, 0.75922703, 0.48668067, 0.45967890
64-
]
65-
)
66-
# fmt: on
67-
# remove non-finite values from x, rsn_rand_1234567890
68-
# to fit the data because
69-
# these value results a non-finite output
70-
x = np.delete(x, 26)
71-
rsn_rand_1234567890 = np.delete(rsn_rand_1234567890, 26)
72-
73-
n = 0.1 * (rsn_rand_1234567890 - 0.5)
69+
rng = np.random.default_rng(0)
70+
rsn_rand_0 = rng.random(x.shape)
71+
n = 0.1 * (rsn_rand_0 - 0.5)
7472

7573
data = model_with_deriv(x) + n
76-
fitter_with_deriv = fitting.LevMarLSQFitter()
74+
fitter_with_deriv = fitting.LMLSQFitter()
7775
new_model_with_deriv = fitter_with_deriv(model_with_deriv, x, data)
78-
fitter_no_deriv = fitting.LevMarLSQFitter()
76+
fitter_no_deriv = fitting.LMLSQFitter()
7977
new_model_no_deriv = fitter_no_deriv(model_no_deriv, x, data, estimate_jacobian=True)
8078
assert_allclose(new_model_with_deriv.parameters, new_model_no_deriv.parameters, atol=0.5)

0 commit comments

Comments
 (0)