Skip to content

Commit 4b7594e

Browse files
committed
Expanded test suite and fixed nan-insertion bug
1 parent 53fc429 commit 4b7594e

14 files changed

+1470
-16
lines changed

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description = "Unified liNe Integration Turbo Engine"
44
readme = "README.md"
55
authors = [{name = "Raphael Erik Hviding", email = "raphael.hviding@gmail.com"}]
66
requires-python = ">= 3.12"
7-
version = "1.4.0"
7+
version = "1.4.1"
88
dependencies = [
99
"numpyro>=0.20.0,<0.21",
1010
"astropy>=7.2.0,<8",

tests/test_continuum.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,3 +1179,184 @@ def test_bspline_scales_knots(self):
11791179
prepared = f._prepare(u.AA, u.um)
11801180
assert prepared._knots[4] == pytest.approx(10000.0)
11811181
assert prepared._n_basis == f._n_basis
1182+
1183+
1184+
# ---------------------------------------------------------------------------
1185+
# is_linear property
1186+
# ---------------------------------------------------------------------------
1187+
1188+
_FLUX_UNIT = u.erg / (u.s * u.cm**2 * u.AA)
1189+
_WL_UNIT = u.um
1190+
_LINEAR_FORMS = [Linear(), Polynomial(2), Chebyshev(2, 0.1),
1191+
BSpline(jnp.array([0.9]*4 + [1.0] + [1.1]*4), degree=3),
1192+
Bernstein(3, 0.9, 1.1)]
1193+
_NONLINEAR_FORMS = [PowerLaw(), Blackbody(), ModifiedBlackbody(), AttenuatedBlackbody()]
1194+
1195+
1196+
class TestIsLinear:
1197+
@pytest.mark.parametrize('form', _LINEAR_FORMS)
1198+
def test_linear_forms_return_true(self, form):
1199+
assert form.is_linear is True
1200+
1201+
@pytest.mark.parametrize('form', _NONLINEAR_FORMS)
1202+
def test_nonlinear_forms_return_false(self, form):
1203+
assert form.is_linear is False
1204+
1205+
1206+
# ---------------------------------------------------------------------------
1207+
# param_units method
1208+
# ---------------------------------------------------------------------------
1209+
1210+
1211+
class TestParamUnits:
1212+
@pytest.mark.parametrize('form', [
1213+
Linear(), PowerLaw(),
1214+
Polynomial(2),
1215+
Chebyshev(2, 0.1),
1216+
BSpline(jnp.array([0.9]*4 + [1.0] + [1.1]*4), degree=3),
1217+
Bernstein(3, 0.9, 1.1),
1218+
Blackbody(), ModifiedBlackbody(), AttenuatedBlackbody(),
1219+
])
1220+
def test_param_units_returns_dict(self, form):
1221+
pu = form.param_units(_FLUX_UNIT, _WL_UNIT)
1222+
assert isinstance(pu, dict)
1223+
assert 'scale' in pu
1224+
# scale should have apply_cs=True and flux_unit
1225+
apply_cs, phys_unit = pu['scale']
1226+
assert apply_cs is True
1227+
1228+
def test_linear_slope_unit(self):
1229+
pu = Linear().param_units(_FLUX_UNIT, _WL_UNIT)
1230+
_, slope_unit = pu['slope']
1231+
assert slope_unit.is_equivalent(_FLUX_UNIT / _WL_UNIT)
1232+
1233+
def test_powerlaw_beta_dimensionless(self):
1234+
pu = PowerLaw().param_units(_FLUX_UNIT, _WL_UNIT)
1235+
_, beta_unit = pu['beta']
1236+
assert beta_unit is None
1237+
1238+
def test_blackbody_temperature_unit(self):
1239+
pu = Blackbody().param_units(_FLUX_UNIT, _WL_UNIT)
1240+
_, temp_unit = pu['temperature']
1241+
assert temp_unit == u.K
1242+
1243+
1244+
# ---------------------------------------------------------------------------
1245+
# default_priors for parameterized forms
1246+
# ---------------------------------------------------------------------------
1247+
1248+
1249+
class TestDefaultPriors:
1250+
def test_chebyshev_default_priors_order2(self):
1251+
priors = Chebyshev(order=2).default_priors(region_center=1.5)
1252+
assert 'c1' in priors
1253+
assert 'c2' in priors
1254+
assert isinstance(priors['normalization_wavelength'], Fixed)
1255+
assert priors['normalization_wavelength'].value == pytest.approx(1.5)
1256+
1257+
def test_polynomial_default_priors_degree2(self):
1258+
priors = Polynomial(degree=2).default_priors(region_center=2.0)
1259+
assert 'c1' in priors
1260+
assert 'c2' in priors
1261+
1262+
def test_bspline_default_priors(self):
1263+
knots = jnp.array([0.9]*4 + [1.0, 1.05, 1.1] + [1.1]*4)
1264+
b = BSpline(knots, degree=3)
1265+
priors = b.default_priors(region_center=1.0)
1266+
assert 'scale' in priors
1267+
for i in range(1, b.n_basis):
1268+
assert f'coeff_{i}' in priors
1269+
1270+
def test_bernstein_default_priors(self):
1271+
b = Bernstein(degree=3, wavelength_min=0.9, wavelength_max=1.1)
1272+
priors = b.default_priors(region_center=1.0)
1273+
assert 'scale' in priors
1274+
assert 'coeff_1' in priors
1275+
1276+
1277+
# ---------------------------------------------------------------------------
1278+
# __eq__ cross-type (NotImplemented) and __hash__ for all forms
1279+
# ---------------------------------------------------------------------------
1280+
1281+
1282+
class TestFormEqHash:
1283+
@pytest.mark.parametrize('form', [
1284+
Linear(), PowerLaw(), Polynomial(2), Chebyshev(2, 0.1),
1285+
Blackbody(), ModifiedBlackbody(), AttenuatedBlackbody(),
1286+
BSpline(jnp.array([0.9]*4 + [1.0] + [1.1]*4), degree=3),
1287+
Bernstein(3, 0.9, 1.1),
1288+
])
1289+
def test_hashable(self, form):
1290+
assert isinstance(hash(form), int)
1291+
1292+
def test_different_types_not_equal(self):
1293+
# ContinuumForm base __eq__ returns NotImplemented for different types
1294+
assert Linear() != PowerLaw()
1295+
assert Blackbody() != ModifiedBlackbody()
1296+
assert Polynomial(2) != Chebyshev(2)
1297+
1298+
def test_polynomial_eq_hash(self):
1299+
assert Polynomial(2) == Polynomial(2)
1300+
assert Polynomial(2) != Polynomial(3)
1301+
assert hash(Polynomial(2)) == hash(Polynomial(2))
1302+
1303+
def test_chebyshev_eq_hash(self):
1304+
assert Chebyshev(2, 0.1) == Chebyshev(2, 0.1)
1305+
assert Chebyshev(2, 0.1) != Chebyshev(2, 0.2)
1306+
assert isinstance(hash(Chebyshev(2, 0.1)), int)
1307+
1308+
def test_attenuated_blackbody_eq_hash(self):
1309+
assert AttenuatedBlackbody(0.55) == AttenuatedBlackbody(0.55)
1310+
assert AttenuatedBlackbody(0.55) != AttenuatedBlackbody(0.50)
1311+
assert isinstance(hash(AttenuatedBlackbody(0.55)), int)
1312+
1313+
def test_bspline_eq_hash(self):
1314+
knots = jnp.array([0.9]*4 + [1.0] + [1.1]*4)
1315+
b1 = BSpline(knots, degree=3)
1316+
b2 = BSpline(knots, degree=3)
1317+
assert b1 == b2
1318+
assert isinstance(hash(b1), int)
1319+
1320+
def test_bernstein_eq_hash(self):
1321+
b1 = Bernstein(3, 0.9, 1.1)
1322+
b2 = Bernstein(3, 0.9, 1.1)
1323+
assert b1 == b2
1324+
assert isinstance(hash(b1), int)
1325+
1326+
1327+
# ---------------------------------------------------------------------------
1328+
# _adapt_for_observed_region
1329+
# ---------------------------------------------------------------------------
1330+
1331+
1332+
class TestAdaptForObservedRegion:
1333+
def test_linear_returns_self(self):
1334+
f = Linear()
1335+
assert f._adapt_for_observed_region(1.0, 2.0) is f
1336+
1337+
def test_chebyshev_updates_half_width(self):
1338+
f = Chebyshev(order=2, half_width=0.5)
1339+
adapted = f._adapt_for_observed_region(0.9, 1.1)
1340+
assert adapted._half_width == pytest.approx((1.1 - 0.9) / 2.0)
1341+
assert adapted._order == 2
1342+
1343+
def test_bspline_rescales_knots(self):
1344+
knots = jnp.array([0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0, 1.0])
1345+
f = BSpline(knots, degree=3)
1346+
adapted = f._adapt_for_observed_region(0.9, 1.1)
1347+
assert float(adapted._knots[0]) == pytest.approx(0.9)
1348+
assert float(adapted._knots[-1]) == pytest.approx(1.1)
1349+
1350+
def test_bspline_identity_knots(self):
1351+
# If all knots are equal, should return self (no rescaling possible)
1352+
knots = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
1353+
f = BSpline(knots, degree=3)
1354+
adapted = f._adapt_for_observed_region(0.9, 1.1)
1355+
assert adapted is f
1356+
1357+
def test_bernstein_updates_bounds(self):
1358+
f = Bernstein(degree=3, wavelength_min=0.0, wavelength_max=1.0)
1359+
adapted = f._adapt_for_observed_region(0.9, 1.1)
1360+
assert adapted._wavelength_min == pytest.approx(0.9)
1361+
assert adapted._wavelength_max == pytest.approx(1.1)
1362+
assert adapted._degree == 3

0 commit comments

Comments
 (0)