@@ -1179,3 +1179,184 @@ def test_bspline_scales_knots(self):
11791179 prepared = f ._prepare (u .AA , u .um )
11801180 assert prepared ._knots [4 ] == pytest .approx (10000.0 )
11811181 assert prepared ._n_basis == f ._n_basis
1182+
1183+
1184+ # ---------------------------------------------------------------------------
1185+ # is_linear property
1186+ # ---------------------------------------------------------------------------
1187+
1188+ _FLUX_UNIT = u .erg / (u .s * u .cm ** 2 * u .AA )
1189+ _WL_UNIT = u .um
1190+ _LINEAR_FORMS = [Linear (), Polynomial (2 ), Chebyshev (2 , 0.1 ),
1191+ BSpline (jnp .array ([0.9 ]* 4 + [1.0 ] + [1.1 ]* 4 ), degree = 3 ),
1192+ Bernstein (3 , 0.9 , 1.1 )]
1193+ _NONLINEAR_FORMS = [PowerLaw (), Blackbody (), ModifiedBlackbody (), AttenuatedBlackbody ()]
1194+
1195+
1196+ class TestIsLinear :
1197+ @pytest .mark .parametrize ('form' , _LINEAR_FORMS )
1198+ def test_linear_forms_return_true (self , form ):
1199+ assert form .is_linear is True
1200+
1201+ @pytest .mark .parametrize ('form' , _NONLINEAR_FORMS )
1202+ def test_nonlinear_forms_return_false (self , form ):
1203+ assert form .is_linear is False
1204+
1205+
1206+ # ---------------------------------------------------------------------------
1207+ # param_units method
1208+ # ---------------------------------------------------------------------------
1209+
1210+
1211+ class TestParamUnits :
1212+ @pytest .mark .parametrize ('form' , [
1213+ Linear (), PowerLaw (),
1214+ Polynomial (2 ),
1215+ Chebyshev (2 , 0.1 ),
1216+ BSpline (jnp .array ([0.9 ]* 4 + [1.0 ] + [1.1 ]* 4 ), degree = 3 ),
1217+ Bernstein (3 , 0.9 , 1.1 ),
1218+ Blackbody (), ModifiedBlackbody (), AttenuatedBlackbody (),
1219+ ])
1220+ def test_param_units_returns_dict (self , form ):
1221+ pu = form .param_units (_FLUX_UNIT , _WL_UNIT )
1222+ assert isinstance (pu , dict )
1223+ assert 'scale' in pu
1224+ # scale should have apply_cs=True and flux_unit
1225+ apply_cs , phys_unit = pu ['scale' ]
1226+ assert apply_cs is True
1227+
1228+ def test_linear_slope_unit (self ):
1229+ pu = Linear ().param_units (_FLUX_UNIT , _WL_UNIT )
1230+ _ , slope_unit = pu ['slope' ]
1231+ assert slope_unit .is_equivalent (_FLUX_UNIT / _WL_UNIT )
1232+
1233+ def test_powerlaw_beta_dimensionless (self ):
1234+ pu = PowerLaw ().param_units (_FLUX_UNIT , _WL_UNIT )
1235+ _ , beta_unit = pu ['beta' ]
1236+ assert beta_unit is None
1237+
1238+ def test_blackbody_temperature_unit (self ):
1239+ pu = Blackbody ().param_units (_FLUX_UNIT , _WL_UNIT )
1240+ _ , temp_unit = pu ['temperature' ]
1241+ assert temp_unit == u .K
1242+
1243+
1244+ # ---------------------------------------------------------------------------
1245+ # default_priors for parameterized forms
1246+ # ---------------------------------------------------------------------------
1247+
1248+
1249+ class TestDefaultPriors :
1250+ def test_chebyshev_default_priors_order2 (self ):
1251+ priors = Chebyshev (order = 2 ).default_priors (region_center = 1.5 )
1252+ assert 'c1' in priors
1253+ assert 'c2' in priors
1254+ assert isinstance (priors ['normalization_wavelength' ], Fixed )
1255+ assert priors ['normalization_wavelength' ].value == pytest .approx (1.5 )
1256+
1257+ def test_polynomial_default_priors_degree2 (self ):
1258+ priors = Polynomial (degree = 2 ).default_priors (region_center = 2.0 )
1259+ assert 'c1' in priors
1260+ assert 'c2' in priors
1261+
1262+ def test_bspline_default_priors (self ):
1263+ knots = jnp .array ([0.9 ]* 4 + [1.0 , 1.05 , 1.1 ] + [1.1 ]* 4 )
1264+ b = BSpline (knots , degree = 3 )
1265+ priors = b .default_priors (region_center = 1.0 )
1266+ assert 'scale' in priors
1267+ for i in range (1 , b .n_basis ):
1268+ assert f'coeff_{ i } ' in priors
1269+
1270+ def test_bernstein_default_priors (self ):
1271+ b = Bernstein (degree = 3 , wavelength_min = 0.9 , wavelength_max = 1.1 )
1272+ priors = b .default_priors (region_center = 1.0 )
1273+ assert 'scale' in priors
1274+ assert 'coeff_1' in priors
1275+
1276+
1277+ # ---------------------------------------------------------------------------
1278+ # __eq__ cross-type (NotImplemented) and __hash__ for all forms
1279+ # ---------------------------------------------------------------------------
1280+
1281+
1282+ class TestFormEqHash :
1283+ @pytest .mark .parametrize ('form' , [
1284+ Linear (), PowerLaw (), Polynomial (2 ), Chebyshev (2 , 0.1 ),
1285+ Blackbody (), ModifiedBlackbody (), AttenuatedBlackbody (),
1286+ BSpline (jnp .array ([0.9 ]* 4 + [1.0 ] + [1.1 ]* 4 ), degree = 3 ),
1287+ Bernstein (3 , 0.9 , 1.1 ),
1288+ ])
1289+ def test_hashable (self , form ):
1290+ assert isinstance (hash (form ), int )
1291+
1292+ def test_different_types_not_equal (self ):
1293+ # ContinuumForm base __eq__ returns NotImplemented for different types
1294+ assert Linear () != PowerLaw ()
1295+ assert Blackbody () != ModifiedBlackbody ()
1296+ assert Polynomial (2 ) != Chebyshev (2 )
1297+
1298+ def test_polynomial_eq_hash (self ):
1299+ assert Polynomial (2 ) == Polynomial (2 )
1300+ assert Polynomial (2 ) != Polynomial (3 )
1301+ assert hash (Polynomial (2 )) == hash (Polynomial (2 ))
1302+
1303+ def test_chebyshev_eq_hash (self ):
1304+ assert Chebyshev (2 , 0.1 ) == Chebyshev (2 , 0.1 )
1305+ assert Chebyshev (2 , 0.1 ) != Chebyshev (2 , 0.2 )
1306+ assert isinstance (hash (Chebyshev (2 , 0.1 )), int )
1307+
1308+ def test_attenuated_blackbody_eq_hash (self ):
1309+ assert AttenuatedBlackbody (0.55 ) == AttenuatedBlackbody (0.55 )
1310+ assert AttenuatedBlackbody (0.55 ) != AttenuatedBlackbody (0.50 )
1311+ assert isinstance (hash (AttenuatedBlackbody (0.55 )), int )
1312+
1313+ def test_bspline_eq_hash (self ):
1314+ knots = jnp .array ([0.9 ]* 4 + [1.0 ] + [1.1 ]* 4 )
1315+ b1 = BSpline (knots , degree = 3 )
1316+ b2 = BSpline (knots , degree = 3 )
1317+ assert b1 == b2
1318+ assert isinstance (hash (b1 ), int )
1319+
1320+ def test_bernstein_eq_hash (self ):
1321+ b1 = Bernstein (3 , 0.9 , 1.1 )
1322+ b2 = Bernstein (3 , 0.9 , 1.1 )
1323+ assert b1 == b2
1324+ assert isinstance (hash (b1 ), int )
1325+
1326+
1327+ # ---------------------------------------------------------------------------
1328+ # _adapt_for_observed_region
1329+ # ---------------------------------------------------------------------------
1330+
1331+
1332+ class TestAdaptForObservedRegion :
1333+ def test_linear_returns_self (self ):
1334+ f = Linear ()
1335+ assert f ._adapt_for_observed_region (1.0 , 2.0 ) is f
1336+
1337+ def test_chebyshev_updates_half_width (self ):
1338+ f = Chebyshev (order = 2 , half_width = 0.5 )
1339+ adapted = f ._adapt_for_observed_region (0.9 , 1.1 )
1340+ assert adapted ._half_width == pytest .approx ((1.1 - 0.9 ) / 2.0 )
1341+ assert adapted ._order == 2
1342+
1343+ def test_bspline_rescales_knots (self ):
1344+ knots = jnp .array ([0.0 , 0.0 , 0.0 , 0.0 , 0.5 , 1.0 , 1.0 , 1.0 , 1.0 ])
1345+ f = BSpline (knots , degree = 3 )
1346+ adapted = f ._adapt_for_observed_region (0.9 , 1.1 )
1347+ assert float (adapted ._knots [0 ]) == pytest .approx (0.9 )
1348+ assert float (adapted ._knots [- 1 ]) == pytest .approx (1.1 )
1349+
1350+ def test_bspline_identity_knots (self ):
1351+ # If all knots are equal, should return self (no rescaling possible)
1352+ knots = jnp .array ([1.0 , 1.0 , 1.0 , 1.0 , 1.0 ])
1353+ f = BSpline (knots , degree = 3 )
1354+ adapted = f ._adapt_for_observed_region (0.9 , 1.1 )
1355+ assert adapted is f
1356+
1357+ def test_bernstein_updates_bounds (self ):
1358+ f = Bernstein (degree = 3 , wavelength_min = 0.0 , wavelength_max = 1.0 )
1359+ adapted = f ._adapt_for_observed_region (0.9 , 1.1 )
1360+ assert adapted ._wavelength_min == pytest .approx (0.9 )
1361+ assert adapted ._wavelength_max == pytest .approx (1.1 )
1362+ assert adapted ._degree == 3
0 commit comments