Skip to content

Commit 35b9926

Browse files
committed
Added continuum diagnostic after calculating scales
1 parent 336717f commit 35b9926

File tree

4 files changed

+205
-11
lines changed

4 files changed

+205
-11
lines changed

docs/guides/spectra.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,45 @@ for spec in spectra:
141141
print(spec.name, spec.error_scale)
142142
```
143143

144+
#### Inspecting the Continuum Fit
145+
146+
After calling `compute_scales`, the fitted continuum model and per-region diagnostics are
147+
available via `Spectra.scale_diagnostics`. This is a list of
148+
`SpectrumScaleDiagnostic` objects — one per spectrum — each containing:
149+
150+
| Attribute | Description |
151+
|---|---|
152+
| `wavelength` | Pixel-centre wavelengths (disperser unit) |
153+
| `flux` / `error` | Observed flux and uncertainty arrays |
154+
| `line_mask` | Boolean array — `True` where a pixel was excluded near an emission line |
155+
| `continuum_model` | Full-length continuum model array; `NaN` outside any fitted region |
156+
| `regions` | List of `RegionDiagnostic` objects, one per continuum region |
157+
158+
Each `RegionDiagnostic` holds `obs_low`, `obs_high`, `in_region`, `good_mask`,
159+
`model_on_region`, `chi2_red`, and `fit_params`.
160+
161+
A typical inspection loop:
162+
163+
```python
164+
spectra.compute_scales(filtered_lines, filtered_cont, error_scale=True)
165+
166+
import numpy as np
167+
for diag in spectra.scale_diagnostics:
168+
wl = np.asarray(diag.wavelength)
169+
flux = np.asarray(diag.flux)
170+
cont = np.asarray(diag.continuum_model) # NaN outside regions
171+
mask = np.asarray(diag.line_mask)
172+
173+
for rinfo in diag.regions:
174+
good = np.asarray(rinfo.good_mask)
175+
model = np.asarray(rinfo.model_on_region) # evaluated on in_region pixels
176+
print(f' chi2_red = {rinfo.chi2_red:.2f}')
177+
```
178+
179+
See `examples/scale_diagnostic_example.py` for a complete plotting script that renders
180+
three-panel figures (spectrum + fit, residuals in σ, residual histogram) for every
181+
continuum region in every spectrum.
182+
144183
### 3. ModelBuilder.build()
145184

146185
```python

pixi.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description = "Unified liNe Intagration Turbo Engine"
44
readme = "README.md"
55
authors = [{name = "Raphael Erik Hviding", email = "raphael.hviding@gmail.com"}]
66
requires-python = ">= 3.12"
7-
version = "1.1.1"
7+
version = "1.2.0"
88
dependencies = [
99
"numpyro>=0.20.0,<0.21",
1010
"astropy>=7.2.0,<8",

unite/spectrum/spectrum.py

Lines changed: 161 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterator, Sequence
6+
from dataclasses import dataclass, field
67
from typing import TYPE_CHECKING
78

89
import jax.numpy as jnp
@@ -284,6 +285,87 @@ def __repr__(self) -> str:
284285
return f'{label}: {self.npix} px, λ ∈ [{lo:.4g}, {hi:.4g}] {unit_str}{cal}'
285286

286287

288+
# ---------------------------------------------------------------------------
289+
# ---------------------------------------------------------------------------
290+
# Scale diagnostics dataclasses
291+
# ---------------------------------------------------------------------------
292+
293+
294+
@dataclass
295+
class RegionDiagnostic:
296+
"""Diagnostics for a single continuum region fit.
297+
298+
Attributes
299+
----------
300+
obs_low : float
301+
Observed-frame lower bound of the region (disperser's unit).
302+
obs_high : float
303+
Observed-frame upper bound of the region (disperser's unit).
304+
in_region : jnp.ndarray
305+
Boolean mask of shape ``(npix,)`` selecting all pixels inside the
306+
region bounds.
307+
good_mask : jnp.ndarray
308+
Boolean mask of shape ``(npix,)`` selecting pixels used for
309+
fitting (``in_region & ~line_mask``).
310+
model_on_region : jnp.ndarray or None
311+
Best-fit continuum model evaluated at the ``in_region`` pixels.
312+
``None`` if the fit failed (too few unmasked pixels).
313+
chi2_red : float or None
314+
Reduced chi-squared of the fit, or ``None`` if the fit failed.
315+
fit_params : dict of str to float
316+
Best-fit parameter dict returned by :func:`~unite.continuum.fit.fit_continuum_form`.
317+
Empty if the fit failed.
318+
"""
319+
320+
obs_low: float
321+
obs_high: float
322+
in_region: jnp.ndarray
323+
good_mask: jnp.ndarray
324+
model_on_region: jnp.ndarray | None
325+
chi2_red: float | None
326+
fit_params: dict = field(default_factory=dict)
327+
328+
329+
@dataclass
330+
class SpectrumScaleDiagnostic:
331+
"""Diagnostics for one spectrum produced by :meth:`Spectra.compute_scales`.
332+
333+
Attributes
334+
----------
335+
name : str
336+
Spectrum name (from :attr:`Spectrum.name`).
337+
wavelength : jnp.ndarray
338+
Pixel-centre wavelengths (disperser's unit), shape ``(npix,)``.
339+
flux : jnp.ndarray
340+
Observed flux values, shape ``(npix,)``.
341+
error : jnp.ndarray
342+
Flux uncertainty values, shape ``(npix,)``.
343+
line_mask : jnp.ndarray
344+
Boolean mask of shape ``(npix,)``; ``True`` where a pixel was
345+
excluded because it lies near an emission line.
346+
continuum_model : jnp.ndarray
347+
Full-spectrum continuum model array of shape ``(npix,)``.
348+
Pixels not covered by any continuum region are ``NaN``.
349+
regions : list of RegionDiagnostic
350+
Per-region fit diagnostics (one entry per region that overlaps
351+
this spectrum).
352+
flux_unit : astropy.units.UnitBase
353+
Flux density unit of *flux* and *error*.
354+
wavelength_unit : astropy.units.UnitBase
355+
Wavelength unit of *wavelength*.
356+
"""
357+
358+
name: str
359+
wavelength: jnp.ndarray
360+
flux: jnp.ndarray
361+
error: jnp.ndarray
362+
line_mask: jnp.ndarray
363+
continuum_model: jnp.ndarray
364+
regions: list
365+
flux_unit: object
366+
wavelength_unit: object
367+
368+
287369
# ---------------------------------------------------------------------------
288370
# ---------------------------------------------------------------------------
289371
# Spectra collection
@@ -339,6 +421,7 @@ def __init__(
339421
self._is_prepared: bool = False
340422
self._prepared_line_config: LineConfiguration | None = None
341423
self._prepared_cont_config: ContinuumConfiguration | None = None
424+
self._scale_diagnostics: list[SpectrumScaleDiagnostic] | None = None
342425

343426
# Canonical wavelength unit: default to the first spectrum's unit.
344427
if canonical_unit is not None:
@@ -420,6 +503,17 @@ def continuum_scale(self, value: u.Quantity) -> None:
420503
raise ValueError(msg)
421504
self._continuum_scale = value
422505

506+
@property
507+
def scale_diagnostics(self) -> list[SpectrumScaleDiagnostic] | None:
508+
"""Per-spectrum diagnostics from the most recent :meth:`compute_scales` call.
509+
510+
Returns a list of :class:`SpectrumScaleDiagnostic` objects (one per
511+
spectrum), each holding the line mask, the fitted continuum model
512+
array, and per-region fit details. ``None`` if :meth:`compute_scales`
513+
has not been called yet.
514+
"""
515+
return self._scale_diagnostics
516+
423517
def compute_scales(
424518
self,
425519
line_config: LineConfiguration,
@@ -500,8 +594,9 @@ def _build_line_mask(spectrum, fwhm_kms):
500594
def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
501595
"""Fit the region's continuum form to unmasked pixels.
502596
503-
Returns (model_values_on_region, good_mask, chi2_red).
504-
model_values_on_region covers all pixels in_region.
597+
Returns (model_on_region, in_region, good, chi2_red, fit_params).
598+
model_on_region covers all pixels in_region; good is the
599+
unmasked pixel mask used for fitting.
505600
"""
506601
from unite.continuum.fit import fit_continuum_form
507602

@@ -511,7 +606,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
511606

512607
min_params = form.n_params # includes normalization_wavelength
513608
if n_good < max(min_params + 1, 3):
514-
return None, in_region, None
609+
return None, in_region, good, None, {}
515610

516611
center = float((obs_low + obs_high) / 2.0)
517612
adapted_form = form._adapt_for_observed_region(obs_low, obs_high)
@@ -522,7 +617,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
522617
# Evaluate on all in-region pixels.
523618
model_region = adapted_form.evaluate(wl[in_region], center, result.params)
524619

525-
return model_region, in_region, result.chi2_red
620+
return model_region, in_region, good, result.chi2_red, result.params
526621

527622
# --- Line scale (with continuum subtraction when available) ---
528623
max_line_scale = 0.0
@@ -550,7 +645,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
550645
line_mask,
551646
region.form,
552647
)
553-
model_region, in_region, _ = result
648+
model_region, in_region, _good, _, _ = result
554649
if model_region is not None:
555650
continuum_est = continuum_est.at[in_region].set(model_region)
556651

@@ -612,7 +707,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
612707
line_mask,
613708
region.form,
614709
)
615-
_, fit_region, chi2_red = result
710+
_, fit_region, _good, chi2_red, _ = result
616711
if chi2_red is not None:
617712
all_chi2_reds.append(chi2_red)
618713
region_scale = float(jnp.sqrt(jnp.maximum(chi2_red, 1.0)))
@@ -634,6 +729,66 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
634729
cont_scale_val = max_cont_scale if max_cont_scale > 0 else 1.0
635730
self._continuum_scale = cont_scale_val * ref_flux_unit
636731

732+
# --- Collect diagnostics ---
733+
diag_list = []
734+
for spectrum in self._spectra:
735+
wl = spectrum.wavelength
736+
line_mask = _build_line_mask(spectrum, mask_fwhm_kms)
737+
continuum_model_full = jnp.full(spectrum.npix, jnp.nan)
738+
region_diags: list[RegionDiagnostic] = []
739+
740+
if continuum_config is not None:
741+
for region in continuum_config:
742+
conv = _wavelength_conversion_factor(region._unit, spectrum.unit)
743+
obs_low = region.low * conv * (1.0 + z)
744+
obs_high = region.high * conv * (1.0 + z)
745+
746+
in_region = (wl >= obs_low) & (wl <= obs_high)
747+
if not jnp.any(in_region):
748+
continue
749+
750+
model_region, in_region, good, chi2_red, fit_params = (
751+
_fit_continuum_region(
752+
wl,
753+
spectrum.flux,
754+
spectrum.error,
755+
obs_low,
756+
obs_high,
757+
line_mask,
758+
region.form,
759+
)
760+
)
761+
if model_region is not None:
762+
continuum_model_full = continuum_model_full.at[in_region].set(
763+
model_region
764+
)
765+
region_diags.append(
766+
RegionDiagnostic(
767+
obs_low=float(obs_low),
768+
obs_high=float(obs_high),
769+
in_region=in_region,
770+
good_mask=good,
771+
model_on_region=model_region,
772+
chi2_red=chi2_red,
773+
fit_params=fit_params,
774+
)
775+
)
776+
777+
diag_list.append(
778+
SpectrumScaleDiagnostic(
779+
name=spectrum.name,
780+
wavelength=wl,
781+
flux=spectrum.flux,
782+
error=spectrum.error,
783+
line_mask=line_mask,
784+
continuum_model=continuum_model_full,
785+
regions=region_diags,
786+
flux_unit=spectrum.flux_unit,
787+
wavelength_unit=spectrum.unit,
788+
)
789+
)
790+
self._scale_diagnostics = diag_list
791+
637792
# -- preparation ----------------------------------------------------------
638793

639794
def prepare(

0 commit comments

Comments
 (0)