1111- Complex redshift hierarchies (systemic → NLR → BLR → outflow)
1212"""
1313
14- import warnings
1514
1615import astropy .units as u
1716import jax .numpy as jnp
2019from jax import random
2120from numpyro .infer import Predictive
2221
23- from unite import line , model , prior
22+ from unite import model
2423from unite .disperser .generic import SimpleDisperser
25- from unite .line .config import FWHM , Flux , LineConfiguration , Param , Redshift
26- from unite .prior import (
27- Fixed ,
28- Parameter ,
29- ParameterRef ,
30- TruncatedNormal ,
31- Uniform ,
32- topological_sort ,
33- )
24+ from unite .line .config import FWHM , Flux , LineConfiguration , Redshift
25+ from unite .prior import ParameterRef , TruncatedNormal , Uniform , topological_sort
3426from unite .spectrum import Spectra , Spectrum
3527
36-
3728# ---------------------------------------------------------------------------
3829# Helpers
3930# ---------------------------------------------------------------------------
4233def _make_spectrum (wl_range = (6400 , 6700 ), npix = 200 , name = 'test' ):
4334 """Create a test spectrum covering the given range."""
4435 wl = np .linspace (* wl_range , npix ) * u .AA
45- disperser = SimpleDisperser (
46- wavelength = wl .value , unit = u .AA , R = 3000.0 , name = name ,
47- )
36+ disperser = SimpleDisperser (wavelength = wl .value , unit = u .AA , R = 3000.0 , name = name )
4837 low = wl - 0.5 * np .gradient (wl )
4938 high = wl + 0.5 * np .gradient (wl )
5039 flux_unit = u .Unit ('1e-17 erg / (s cm2 AA)' )
5140 rng = np .random .default_rng (42 )
5241 flux = (10.0 + rng .normal (0 , 1 , npix )) * flux_unit
5342 error = np .full (npix , 1.0 ) * flux_unit
5443 return Spectrum (
55- low = low , high = high , flux = flux , error = error ,
56- disperser = disperser , name = name ,
44+ low = low , high = high , flux = flux , error = error , disperser = disperser , name = name
5745 )
5846
5947
6048def _prepare_and_build (line_config , spectra , cont_config = None ):
6149 """Prepare spectra and build model."""
6250 spectra .prepare (line_config , cont_config )
63- spectra .compute_scales (
64- spectra .prepared_line_config , spectra .prepared_cont_config ,
65- )
51+ spectra .compute_scales (spectra .prepared_line_config , spectra .prepared_cont_config )
6652 return model .ModelBuilder (
67- spectra .prepared_line_config , spectra .prepared_cont_config , spectra ,
53+ spectra .prepared_line_config , spectra .prepared_cont_config , spectra
6854 ).build ()
6955
7056
@@ -100,9 +86,7 @@ def test_four_level_chain(self):
10086 c = FWHM ('c' , prior = Uniform (low = b * 1.5 , high = 1000 ))
10187 d = FWHM ('d' , prior = Uniform (low = c + 100 , high = 5000 ))
10288
103- named_priors = {
104- 'a' : a .prior , 'b' : b .prior , 'c' : c .prior , 'd' : d .prior ,
105- }
89+ named_priors = {'a' : a .prior , 'b' : b .prior , 'c' : c .prior , 'd' : d .prior }
10690 param_to_name = {a : 'a' , b : 'b' , c : 'c' , d : 'd' }
10791 order = topological_sort (named_priors , param_to_name )
10892
@@ -114,16 +98,10 @@ def test_three_level_redshift_chain(self):
11498 z_nlr = Redshift (
11599 'z_nlr' ,
116100 prior = TruncatedNormal (
117- loc = z_sys ,
118- scale = 0.001 ,
119- low = z_sys - 0.005 ,
120- high = z_sys + 0.005 ,
101+ loc = z_sys , scale = 0.001 , low = z_sys - 0.005 , high = z_sys + 0.005
121102 ),
122103 )
123- z_out = Redshift (
124- 'z_out' ,
125- prior = Uniform (low = z_nlr - 0.01 , high = z_nlr ),
126- )
104+ z_out = Redshift ('z_out' , prior = Uniform (low = z_nlr - 0.01 , high = z_nlr ))
127105
128106 named_priors = {
129107 'z_sys' : z_sys .prior ,
@@ -171,8 +149,7 @@ def test_redshift_token_as_loc(self):
171149 """Redshift token passed directly as TruncatedNormal loc."""
172150 z_sys = Redshift ('z_sys' , prior = Uniform (- 0.01 , 0.01 ))
173151 z_nlr = Redshift (
174- 'z_nlr' ,
175- prior = TruncatedNormal (loc = z_sys , scale = 0.001 , low = - 0.02 , high = 0.02 ),
152+ 'z_nlr' , prior = TruncatedNormal (loc = z_sys , scale = 0.001 , low = - 0.02 , high = 0.02 )
176153 )
177154 assert z_sys in z_nlr .prior .dependencies ()
178155
@@ -193,12 +170,7 @@ def test_all_three_bounds_depend_on_different_params(self):
193170
194171 constrained = FWHM (
195172 'constrained' ,
196- prior = TruncatedNormal (
197- loc = center ,
198- scale = 50.0 ,
199- low = lower ,
200- high = upper ,
201- ),
173+ prior = TruncatedNormal (loc = center , scale = 50.0 , low = lower , high = upper ),
202174 )
203175
204176 deps = constrained .prior .dependencies ()
@@ -211,12 +183,7 @@ def test_loc_and_low_same_token(self):
211183 base = FWHM ('base' , prior = Uniform (100 , 500 ))
212184 derived = FWHM (
213185 'derived' ,
214- prior = TruncatedNormal (
215- loc = base + 100 ,
216- scale = 30.0 ,
217- low = base ,
218- high = 2000 ,
219- ),
186+ prior = TruncatedNormal (loc = base + 100 , scale = 30.0 , low = base , high = 2000 ),
220187 )
221188
222189 deps = derived .prior .dependencies ()
@@ -244,10 +211,7 @@ def test_flux_ratio_via_parameter_ref(self):
244211 f_weak = Flux (
245212 'NII_6549' ,
246213 prior = TruncatedNormal (
247- loc = f_strong / 2.95 ,
248- scale = 0.1 ,
249- low = f_strong / 4.0 ,
250- high = f_strong / 2.0 ,
214+ loc = f_strong / 2.95 , scale = 0.1 , low = f_strong / 4.0 , high = f_strong / 2.0
251215 ),
252216 )
253217
@@ -263,17 +227,11 @@ def test_flux_ratio_via_parameter_ref(self):
263227 def test_flux_chain (self ):
264228 """Three-line flux chain: Ha → [NII]6585 → [NII]6549."""
265229 f_ha = Flux ('Ha' , prior = Uniform (0 , 20 ))
266- f_nii_s = Flux (
267- 'NII_6585' ,
268- prior = Uniform (low = 0 , high = f_ha * 2 ),
269- )
230+ f_nii_s = Flux ('NII_6585' , prior = Uniform (low = 0 , high = f_ha * 2 ))
270231 f_nii_w = Flux (
271232 'NII_6549' ,
272233 prior = TruncatedNormal (
273- loc = f_nii_s / 2.95 ,
274- scale = 0.05 ,
275- low = f_nii_s / 4.0 ,
276- high = f_nii_s / 2.0 ,
234+ loc = f_nii_s / 2.95 , scale = 0.05 , low = f_nii_s / 4.0 , high = f_nii_s / 2.0
277235 ),
278236 )
279237
@@ -316,27 +274,15 @@ def _make_three_component_config(self):
316274 z_blr = Redshift (
317275 'z_blr' ,
318276 prior = TruncatedNormal (
319- loc = z_nlr ,
320- scale = 0.002 ,
321- low = z_nlr - 0.01 ,
322- high = z_nlr + 0.01 ,
277+ loc = z_nlr , scale = 0.002 , low = z_nlr - 0.01 , high = z_nlr + 0.01
323278 ),
324279 )
325- z_out = Redshift (
326- 'z_out' ,
327- prior = Uniform (low = z_nlr - 0.02 , high = z_nlr ),
328- )
280+ z_out = Redshift ('z_out' , prior = Uniform (low = z_nlr - 0.02 , high = z_nlr ))
329281
330282 # -- FWHM hierarchy --
331283 fwhm_narrow = FWHM ('fwhm_narrow' , prior = Uniform (50 , 500 ))
332- fwhm_broad = FWHM (
333- 'fwhm_broad' ,
334- prior = Uniform (low = fwhm_narrow + 200 , high = 5000 ),
335- )
336- fwhm_out = FWHM (
337- 'fwhm_out' ,
338- prior = Uniform (low = fwhm_broad , high = 8000 ),
339- )
284+ fwhm_broad = FWHM ('fwhm_broad' , prior = Uniform (low = fwhm_narrow + 200 , high = 5000 ))
285+ fwhm_out = FWHM ('fwhm_out' , prior = Uniform (low = fwhm_broad , high = 8000 ))
340286
341287 # -- Flux with doublet ratio --
342288 f_ha_n = Flux ('Ha_n' , prior = Uniform (0 , 10 ))
@@ -346,25 +292,40 @@ def _make_three_component_config(self):
346292 f_nii_w = Flux (
347293 'NII_w' ,
348294 prior = TruncatedNormal (
349- loc = f_nii_s / 2.95 ,
350- scale = 0.1 ,
351- low = f_nii_s / 4.0 ,
352- high = f_nii_s / 2.0 ,
295+ loc = f_nii_s / 2.95 , scale = 0.1 , low = f_nii_s / 4.0 , high = f_nii_s / 2.0
353296 ),
354297 )
355298
356299 lc = LineConfiguration ()
357300
358301 # Narrow lines
359- lc .add_line ('Ha' , 6564.61 * u .AA , redshift = z_nlr , fwhm_gauss = fwhm_narrow , flux = f_ha_n )
360- lc .add_line ('NII_6585' , 6585.27 * u .AA , redshift = z_nlr , fwhm_gauss = fwhm_narrow , flux = f_nii_s )
361- lc .add_line ('NII_6549' , 6549.86 * u .AA , redshift = z_nlr , fwhm_gauss = fwhm_narrow , flux = f_nii_w )
302+ lc .add_line (
303+ 'Ha' , 6564.61 * u .AA , redshift = z_nlr , fwhm_gauss = fwhm_narrow , flux = f_ha_n
304+ )
305+ lc .add_line (
306+ 'NII_6585' ,
307+ 6585.27 * u .AA ,
308+ redshift = z_nlr ,
309+ fwhm_gauss = fwhm_narrow ,
310+ flux = f_nii_s ,
311+ )
312+ lc .add_line (
313+ 'NII_6549' ,
314+ 6549.86 * u .AA ,
315+ redshift = z_nlr ,
316+ fwhm_gauss = fwhm_narrow ,
317+ flux = f_nii_w ,
318+ )
362319
363320 # Broad lines
364- lc .add_line ('Ha' , 6564.61 * u .AA , redshift = z_blr , fwhm_gauss = fwhm_broad , flux = f_ha_b )
321+ lc .add_line (
322+ 'Ha' , 6564.61 * u .AA , redshift = z_blr , fwhm_gauss = fwhm_broad , flux = f_ha_b
323+ )
365324
366325 # Outflow lines
367- lc .add_line ('Ha' , 6564.61 * u .AA , redshift = z_out , fwhm_gauss = fwhm_out , flux = f_ha_out )
326+ lc .add_line (
327+ 'Ha' , 6564.61 * u .AA , redshift = z_out , fwhm_gauss = fwhm_out , flux = f_ha_out
328+ )
368329
369330 return lc
370331
@@ -463,20 +424,13 @@ def test_fwhm_depends_on_two_parents(self):
463424 """A FWHM with low from one token and high from another."""
464425 lower = FWHM ('lower' , prior = Uniform (50 , 200 ))
465426 upper = FWHM ('upper' , prior = Uniform (800 , 2000 ))
466- mid = FWHM (
467- 'mid' ,
468- prior = Uniform (low = lower + 50 , high = upper - 50 ),
469- )
427+ mid = FWHM ('mid' , prior = Uniform (low = lower + 50 , high = upper - 50 ))
470428
471429 deps = mid .prior .dependencies ()
472430 assert lower in deps
473431 assert upper in deps
474432
475- named_priors = {
476- 'lower' : lower .prior ,
477- 'upper' : upper .prior ,
478- 'mid' : mid .prior ,
479- }
433+ named_priors = {'lower' : lower .prior , 'upper' : upper .prior , 'mid' : mid .prior }
480434 param_to_name = {lower : 'lower' , upper : 'upper' , mid : 'mid' }
481435 order = topological_sort (named_priors , param_to_name )
482436 assert order .index ('lower' ) < order .index ('mid' )
@@ -489,9 +443,7 @@ def test_diamond_with_convergent_child(self):
489443 c = FWHM ('c' , prior = Uniform (low = a , high = b ))
490444 d = FWHM ('d' , prior = Uniform (low = a + 50 , high = b - 50 ))
491445
492- named_priors = {
493- 'a' : a .prior , 'b' : b .prior , 'c' : c .prior , 'd' : d .prior ,
494- }
446+ named_priors = {'a' : a .prior , 'b' : b .prior , 'c' : c .prior , 'd' : d .prior }
495447 param_to_name = {a : 'a' , b : 'b' , c : 'c' , d : 'd' }
496448 order = topological_sort (named_priors , param_to_name )
497449
@@ -543,10 +495,7 @@ def test_nested_ref_in_truncated_normal(self):
543495 derived = FWHM (
544496 'derived' ,
545497 prior = TruncatedNormal (
546- loc = base * 1.5 + 50 ,
547- scale = 30.0 ,
548- low = base + 20 ,
549- high = base * 3 ,
498+ loc = base * 1.5 + 50 , scale = 30.0 , low = base + 20 , high = base * 3
550499 ),
551500 )
552501 context = {base : 200.0 }
@@ -595,10 +544,7 @@ def test_flux_ratio_serialization(self):
595544 f_weak = Flux (
596545 'f_w' ,
597546 prior = TruncatedNormal (
598- loc = f_strong / 2.95 ,
599- scale = 0.1 ,
600- low = f_strong / 4.0 ,
601- high = f_strong / 2.0 ,
547+ loc = f_strong / 2.95 , scale = 0.1 , low = f_strong / 4.0 , high = f_strong / 2.0
602548 ),
603549 )
604550
@@ -651,10 +597,7 @@ class TestEndToEndDeepDependencies:
651597 def test_narrow_broad_model_respects_ordering (self ):
652598 """Verify sampled broad FWHM > sampled narrow FWHM + offset."""
653599 fwhm_narrow = FWHM ('fwhm_narrow' , prior = Uniform (50 , 300 ))
654- fwhm_broad = FWHM (
655- 'fwhm_broad' ,
656- prior = Uniform (low = fwhm_narrow + 200 , high = 3000 ),
657- )
600+ fwhm_broad = FWHM ('fwhm_broad' , prior = Uniform (low = fwhm_narrow + 200 , high = 3000 ))
658601
659602 lc = LineConfiguration ()
660603 z = Redshift ('z' , prior = Uniform (- 0.005 , 0.005 ))
@@ -709,10 +652,7 @@ def test_redshift_hierarchy_model(self):
709652 z_nlr = Redshift (
710653 'z_nlr' ,
711654 prior = TruncatedNormal (
712- loc = z_sys ,
713- scale = 0.001 ,
714- low = z_sys - 0.003 ,
715- high = z_sys + 0.003 ,
655+ loc = z_sys , scale = 0.001 , low = z_sys - 0.003 , high = z_sys + 0.003
716656 ),
717657 )
718658
0 commit comments