Skip to content

Commit d2ccd7a

Browse files
committed
Fix CI/Lint issues
1 parent 6e19a87 commit d2ccd7a

File tree

13 files changed

+92
-164
lines changed

13 files changed

+92
-164
lines changed

tests/test_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
import warnings
44

5-
import pytest
65
from astropy import units as u
76

87
from unite.config import Configuration
98
from unite.continuum import ContinuumConfiguration, Linear
109
from unite.disperser.base import FluxScale, RScale
1110
from unite.disperser.config import DispersersConfiguration
1211
from unite.instruments.nirspec import G235H, G395H
13-
from unite.line import FWHM, Flux, LineConfiguration, Redshift
12+
from unite.line import FWHM, LineConfiguration, Redshift
1413
from unite.prior import TruncatedNormal, Uniform
1514

1615

tests/test_continuum_functions.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import jax.numpy as jnp
44
import numpy as np
5-
import pytest
65
from scipy.special import comb as scipy_comb
76

87
from unite.continuum.functions import (
@@ -13,7 +12,6 @@
1312
planck_function,
1413
)
1514

16-
1715
# ---------------------------------------------------------------------------
1816
# Planck function
1917
# ---------------------------------------------------------------------------
@@ -109,11 +107,7 @@ def _make_clamped_knots(self, n_basis, degree, low=0.0, high=1.0):
109107
"""Create a clamped knot vector."""
110108
n_internal = n_basis - degree + 1
111109
internal = np.linspace(low, high, n_internal)
112-
knots = np.concatenate([
113-
np.full(degree, low),
114-
internal,
115-
np.full(degree, high),
116-
])
110+
knots = np.concatenate([np.full(degree, low), internal, np.full(degree, high)])
117111
return jnp.asarray(knots)
118112

119113
def test_basis_partition_of_unity(self):

tests/test_dependent_priors.py

Lines changed: 51 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
- Complex redshift hierarchies (systemic → NLR → BLR → outflow)
1212
"""
1313

14-
import warnings
1514

1615
import astropy.units as u
1716
import jax.numpy as jnp
@@ -20,20 +19,12 @@
2019
from jax import random
2120
from numpyro.infer import Predictive
2221

23-
from unite import line, model, prior
22+
from unite import model
2423
from unite.disperser.generic import SimpleDisperser
25-
from unite.line.config import FWHM, Flux, LineConfiguration, Param, Redshift
26-
from unite.prior import (
27-
Fixed,
28-
Parameter,
29-
ParameterRef,
30-
TruncatedNormal,
31-
Uniform,
32-
topological_sort,
33-
)
24+
from unite.line.config import FWHM, Flux, LineConfiguration, Redshift
25+
from unite.prior import ParameterRef, TruncatedNormal, Uniform, topological_sort
3426
from unite.spectrum import Spectra, Spectrum
3527

36-
3728
# ---------------------------------------------------------------------------
3829
# Helpers
3930
# ---------------------------------------------------------------------------
@@ -42,29 +33,24 @@
4233
def _make_spectrum(wl_range=(6400, 6700), npix=200, name='test'):
4334
"""Create a test spectrum covering the given range."""
4435
wl = np.linspace(*wl_range, npix) * u.AA
45-
disperser = SimpleDisperser(
46-
wavelength=wl.value, unit=u.AA, R=3000.0, name=name,
47-
)
36+
disperser = SimpleDisperser(wavelength=wl.value, unit=u.AA, R=3000.0, name=name)
4837
low = wl - 0.5 * np.gradient(wl)
4938
high = wl + 0.5 * np.gradient(wl)
5039
flux_unit = u.Unit('1e-17 erg / (s cm2 AA)')
5140
rng = np.random.default_rng(42)
5241
flux = (10.0 + rng.normal(0, 1, npix)) * flux_unit
5342
error = np.full(npix, 1.0) * flux_unit
5443
return Spectrum(
55-
low=low, high=high, flux=flux, error=error,
56-
disperser=disperser, name=name,
44+
low=low, high=high, flux=flux, error=error, disperser=disperser, name=name
5745
)
5846

5947

6048
def _prepare_and_build(line_config, spectra, cont_config=None):
6149
"""Prepare spectra and build model."""
6250
spectra.prepare(line_config, cont_config)
63-
spectra.compute_scales(
64-
spectra.prepared_line_config, spectra.prepared_cont_config,
65-
)
51+
spectra.compute_scales(spectra.prepared_line_config, spectra.prepared_cont_config)
6652
return model.ModelBuilder(
67-
spectra.prepared_line_config, spectra.prepared_cont_config, spectra,
53+
spectra.prepared_line_config, spectra.prepared_cont_config, spectra
6854
).build()
6955

7056

@@ -100,9 +86,7 @@ def test_four_level_chain(self):
10086
c = FWHM('c', prior=Uniform(low=b * 1.5, high=1000))
10187
d = FWHM('d', prior=Uniform(low=c + 100, high=5000))
10288

103-
named_priors = {
104-
'a': a.prior, 'b': b.prior, 'c': c.prior, 'd': d.prior,
105-
}
89+
named_priors = {'a': a.prior, 'b': b.prior, 'c': c.prior, 'd': d.prior}
10690
param_to_name = {a: 'a', b: 'b', c: 'c', d: 'd'}
10791
order = topological_sort(named_priors, param_to_name)
10892

@@ -114,16 +98,10 @@ def test_three_level_redshift_chain(self):
11498
z_nlr = Redshift(
11599
'z_nlr',
116100
prior=TruncatedNormal(
117-
loc=z_sys,
118-
scale=0.001,
119-
low=z_sys - 0.005,
120-
high=z_sys + 0.005,
101+
loc=z_sys, scale=0.001, low=z_sys - 0.005, high=z_sys + 0.005
121102
),
122103
)
123-
z_out = Redshift(
124-
'z_out',
125-
prior=Uniform(low=z_nlr - 0.01, high=z_nlr),
126-
)
104+
z_out = Redshift('z_out', prior=Uniform(low=z_nlr - 0.01, high=z_nlr))
127105

128106
named_priors = {
129107
'z_sys': z_sys.prior,
@@ -171,8 +149,7 @@ def test_redshift_token_as_loc(self):
171149
"""Redshift token passed directly as TruncatedNormal loc."""
172150
z_sys = Redshift('z_sys', prior=Uniform(-0.01, 0.01))
173151
z_nlr = Redshift(
174-
'z_nlr',
175-
prior=TruncatedNormal(loc=z_sys, scale=0.001, low=-0.02, high=0.02),
152+
'z_nlr', prior=TruncatedNormal(loc=z_sys, scale=0.001, low=-0.02, high=0.02)
176153
)
177154
assert z_sys in z_nlr.prior.dependencies()
178155

@@ -193,12 +170,7 @@ def test_all_three_bounds_depend_on_different_params(self):
193170

194171
constrained = FWHM(
195172
'constrained',
196-
prior=TruncatedNormal(
197-
loc=center,
198-
scale=50.0,
199-
low=lower,
200-
high=upper,
201-
),
173+
prior=TruncatedNormal(loc=center, scale=50.0, low=lower, high=upper),
202174
)
203175

204176
deps = constrained.prior.dependencies()
@@ -211,12 +183,7 @@ def test_loc_and_low_same_token(self):
211183
base = FWHM('base', prior=Uniform(100, 500))
212184
derived = FWHM(
213185
'derived',
214-
prior=TruncatedNormal(
215-
loc=base + 100,
216-
scale=30.0,
217-
low=base,
218-
high=2000,
219-
),
186+
prior=TruncatedNormal(loc=base + 100, scale=30.0, low=base, high=2000),
220187
)
221188

222189
deps = derived.prior.dependencies()
@@ -244,10 +211,7 @@ def test_flux_ratio_via_parameter_ref(self):
244211
f_weak = Flux(
245212
'NII_6549',
246213
prior=TruncatedNormal(
247-
loc=f_strong / 2.95,
248-
scale=0.1,
249-
low=f_strong / 4.0,
250-
high=f_strong / 2.0,
214+
loc=f_strong / 2.95, scale=0.1, low=f_strong / 4.0, high=f_strong / 2.0
251215
),
252216
)
253217

@@ -263,17 +227,11 @@ def test_flux_ratio_via_parameter_ref(self):
263227
def test_flux_chain(self):
264228
"""Three-line flux chain: Ha → [NII]6585 → [NII]6549."""
265229
f_ha = Flux('Ha', prior=Uniform(0, 20))
266-
f_nii_s = Flux(
267-
'NII_6585',
268-
prior=Uniform(low=0, high=f_ha * 2),
269-
)
230+
f_nii_s = Flux('NII_6585', prior=Uniform(low=0, high=f_ha * 2))
270231
f_nii_w = Flux(
271232
'NII_6549',
272233
prior=TruncatedNormal(
273-
loc=f_nii_s / 2.95,
274-
scale=0.05,
275-
low=f_nii_s / 4.0,
276-
high=f_nii_s / 2.0,
234+
loc=f_nii_s / 2.95, scale=0.05, low=f_nii_s / 4.0, high=f_nii_s / 2.0
277235
),
278236
)
279237

@@ -316,27 +274,15 @@ def _make_three_component_config(self):
316274
z_blr = Redshift(
317275
'z_blr',
318276
prior=TruncatedNormal(
319-
loc=z_nlr,
320-
scale=0.002,
321-
low=z_nlr - 0.01,
322-
high=z_nlr + 0.01,
277+
loc=z_nlr, scale=0.002, low=z_nlr - 0.01, high=z_nlr + 0.01
323278
),
324279
)
325-
z_out = Redshift(
326-
'z_out',
327-
prior=Uniform(low=z_nlr - 0.02, high=z_nlr),
328-
)
280+
z_out = Redshift('z_out', prior=Uniform(low=z_nlr - 0.02, high=z_nlr))
329281

330282
# -- FWHM hierarchy --
331283
fwhm_narrow = FWHM('fwhm_narrow', prior=Uniform(50, 500))
332-
fwhm_broad = FWHM(
333-
'fwhm_broad',
334-
prior=Uniform(low=fwhm_narrow + 200, high=5000),
335-
)
336-
fwhm_out = FWHM(
337-
'fwhm_out',
338-
prior=Uniform(low=fwhm_broad, high=8000),
339-
)
284+
fwhm_broad = FWHM('fwhm_broad', prior=Uniform(low=fwhm_narrow + 200, high=5000))
285+
fwhm_out = FWHM('fwhm_out', prior=Uniform(low=fwhm_broad, high=8000))
340286

341287
# -- Flux with doublet ratio --
342288
f_ha_n = Flux('Ha_n', prior=Uniform(0, 10))
@@ -346,25 +292,40 @@ def _make_three_component_config(self):
346292
f_nii_w = Flux(
347293
'NII_w',
348294
prior=TruncatedNormal(
349-
loc=f_nii_s / 2.95,
350-
scale=0.1,
351-
low=f_nii_s / 4.0,
352-
high=f_nii_s / 2.0,
295+
loc=f_nii_s / 2.95, scale=0.1, low=f_nii_s / 4.0, high=f_nii_s / 2.0
353296
),
354297
)
355298

356299
lc = LineConfiguration()
357300

358301
# Narrow lines
359-
lc.add_line('Ha', 6564.61 * u.AA, redshift=z_nlr, fwhm_gauss=fwhm_narrow, flux=f_ha_n)
360-
lc.add_line('NII_6585', 6585.27 * u.AA, redshift=z_nlr, fwhm_gauss=fwhm_narrow, flux=f_nii_s)
361-
lc.add_line('NII_6549', 6549.86 * u.AA, redshift=z_nlr, fwhm_gauss=fwhm_narrow, flux=f_nii_w)
302+
lc.add_line(
303+
'Ha', 6564.61 * u.AA, redshift=z_nlr, fwhm_gauss=fwhm_narrow, flux=f_ha_n
304+
)
305+
lc.add_line(
306+
'NII_6585',
307+
6585.27 * u.AA,
308+
redshift=z_nlr,
309+
fwhm_gauss=fwhm_narrow,
310+
flux=f_nii_s,
311+
)
312+
lc.add_line(
313+
'NII_6549',
314+
6549.86 * u.AA,
315+
redshift=z_nlr,
316+
fwhm_gauss=fwhm_narrow,
317+
flux=f_nii_w,
318+
)
362319

363320
# Broad lines
364-
lc.add_line('Ha', 6564.61 * u.AA, redshift=z_blr, fwhm_gauss=fwhm_broad, flux=f_ha_b)
321+
lc.add_line(
322+
'Ha', 6564.61 * u.AA, redshift=z_blr, fwhm_gauss=fwhm_broad, flux=f_ha_b
323+
)
365324

366325
# Outflow lines
367-
lc.add_line('Ha', 6564.61 * u.AA, redshift=z_out, fwhm_gauss=fwhm_out, flux=f_ha_out)
326+
lc.add_line(
327+
'Ha', 6564.61 * u.AA, redshift=z_out, fwhm_gauss=fwhm_out, flux=f_ha_out
328+
)
368329

369330
return lc
370331

@@ -463,20 +424,13 @@ def test_fwhm_depends_on_two_parents(self):
463424
"""A FWHM with low from one token and high from another."""
464425
lower = FWHM('lower', prior=Uniform(50, 200))
465426
upper = FWHM('upper', prior=Uniform(800, 2000))
466-
mid = FWHM(
467-
'mid',
468-
prior=Uniform(low=lower + 50, high=upper - 50),
469-
)
427+
mid = FWHM('mid', prior=Uniform(low=lower + 50, high=upper - 50))
470428

471429
deps = mid.prior.dependencies()
472430
assert lower in deps
473431
assert upper in deps
474432

475-
named_priors = {
476-
'lower': lower.prior,
477-
'upper': upper.prior,
478-
'mid': mid.prior,
479-
}
433+
named_priors = {'lower': lower.prior, 'upper': upper.prior, 'mid': mid.prior}
480434
param_to_name = {lower: 'lower', upper: 'upper', mid: 'mid'}
481435
order = topological_sort(named_priors, param_to_name)
482436
assert order.index('lower') < order.index('mid')
@@ -489,9 +443,7 @@ def test_diamond_with_convergent_child(self):
489443
c = FWHM('c', prior=Uniform(low=a, high=b))
490444
d = FWHM('d', prior=Uniform(low=a + 50, high=b - 50))
491445

492-
named_priors = {
493-
'a': a.prior, 'b': b.prior, 'c': c.prior, 'd': d.prior,
494-
}
446+
named_priors = {'a': a.prior, 'b': b.prior, 'c': c.prior, 'd': d.prior}
495447
param_to_name = {a: 'a', b: 'b', c: 'c', d: 'd'}
496448
order = topological_sort(named_priors, param_to_name)
497449

@@ -543,10 +495,7 @@ def test_nested_ref_in_truncated_normal(self):
543495
derived = FWHM(
544496
'derived',
545497
prior=TruncatedNormal(
546-
loc=base * 1.5 + 50,
547-
scale=30.0,
548-
low=base + 20,
549-
high=base * 3,
498+
loc=base * 1.5 + 50, scale=30.0, low=base + 20, high=base * 3
550499
),
551500
)
552501
context = {base: 200.0}
@@ -595,10 +544,7 @@ def test_flux_ratio_serialization(self):
595544
f_weak = Flux(
596545
'f_w',
597546
prior=TruncatedNormal(
598-
loc=f_strong / 2.95,
599-
scale=0.1,
600-
low=f_strong / 4.0,
601-
high=f_strong / 2.0,
547+
loc=f_strong / 2.95, scale=0.1, low=f_strong / 4.0, high=f_strong / 2.0
602548
),
603549
)
604550

@@ -651,10 +597,7 @@ class TestEndToEndDeepDependencies:
651597
def test_narrow_broad_model_respects_ordering(self):
652598
"""Verify sampled broad FWHM > sampled narrow FWHM + offset."""
653599
fwhm_narrow = FWHM('fwhm_narrow', prior=Uniform(50, 300))
654-
fwhm_broad = FWHM(
655-
'fwhm_broad',
656-
prior=Uniform(low=fwhm_narrow + 200, high=3000),
657-
)
600+
fwhm_broad = FWHM('fwhm_broad', prior=Uniform(low=fwhm_narrow + 200, high=3000))
658601

659602
lc = LineConfiguration()
660603
z = Redshift('z', prior=Uniform(-0.005, 0.005))
@@ -709,10 +652,7 @@ def test_redshift_hierarchy_model(self):
709652
z_nlr = Redshift(
710653
'z_nlr',
711654
prior=TruncatedNormal(
712-
loc=z_sys,
713-
scale=0.001,
714-
low=z_sys - 0.003,
715-
high=z_sys + 0.003,
655+
loc=z_sys, scale=0.001, low=z_sys - 0.003, high=z_sys + 0.003
716656
),
717657
)
718658

tests/test_disperser_base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from unite.disperser.base import Disperser, FluxScale, PixOffset, RScale
77
from unite.prior import Fixed, TruncatedNormal, Uniform
88

9-
109
# ---------------------------------------------------------------------------
1110
# Calibration token construction
1211
# ---------------------------------------------------------------------------
@@ -120,9 +119,7 @@ def test_has_calibration_params_false(self):
120119
from unite.disperser.generic import GenericDisperser
121120

122121
d = GenericDisperser(
123-
R_func=lambda w: w * 0 + 1000,
124-
dlam_dpix_func=lambda w: w / 1000,
125-
unit=u.AA,
122+
R_func=lambda w: w * 0 + 1000, dlam_dpix_func=lambda w: w / 1000, unit=u.AA
126123
)
127124
assert not d.has_calibration_params
128125

0 commit comments

Comments
 (0)