33from __future__ import annotations
44
55from collections .abc import Iterator , Sequence
6+ from dataclasses import dataclass , field
67from typing import TYPE_CHECKING
78
89import jax .numpy as jnp
@@ -284,6 +285,87 @@ def __repr__(self) -> str:
284285 return f'{ label } : { self .npix } px, λ ∈ [{ lo :.4g} , { hi :.4g} ] { unit_str } { cal } '
285286
286287
288+ # ---------------------------------------------------------------------------
289+ # ---------------------------------------------------------------------------
290+ # Scale diagnostics dataclasses
291+ # ---------------------------------------------------------------------------
292+
293+
294+ @dataclass
295+ class RegionDiagnostic :
296+ """Diagnostics for a single continuum region fit.
297+
298+ Attributes
299+ ----------
300+ obs_low : float
301+ Observed-frame lower bound of the region (disperser's unit).
302+ obs_high : float
303+ Observed-frame upper bound of the region (disperser's unit).
304+ in_region : jnp.ndarray
305+ Boolean mask of shape ``(npix,)`` selecting all pixels inside the
306+ region bounds.
307+ good_mask : jnp.ndarray
308+ Boolean mask of shape ``(npix,)`` selecting pixels used for
309+ fitting (``in_region & ~line_mask``).
310+ model_on_region : jnp.ndarray or None
311+ Best-fit continuum model evaluated at the ``in_region`` pixels.
312+ ``None`` if the fit failed (too few unmasked pixels).
313+ chi2_red : float or None
314+ Reduced chi-squared of the fit, or ``None`` if the fit failed.
315+ fit_params : dict of str to float
316+ Best-fit parameter dict returned by :func:`~unite.continuum.fit.fit_continuum_form`.
317+ Empty if the fit failed.
318+ """
319+
320+ obs_low : float
321+ obs_high : float
322+ in_region : jnp .ndarray
323+ good_mask : jnp .ndarray
324+ model_on_region : jnp .ndarray | None
325+ chi2_red : float | None
326+ fit_params : dict = field (default_factory = dict )
327+
328+
329+ @dataclass
330+ class SpectrumScaleDiagnostic :
331+ """Diagnostics for one spectrum produced by :meth:`Spectra.compute_scales`.
332+
333+ Attributes
334+ ----------
335+ name : str
336+ Spectrum name (from :attr:`Spectrum.name`).
337+ wavelength : jnp.ndarray
338+ Pixel-centre wavelengths (disperser's unit), shape ``(npix,)``.
339+ flux : jnp.ndarray
340+ Observed flux values, shape ``(npix,)``.
341+ error : jnp.ndarray
342+ Flux uncertainty values, shape ``(npix,)``.
343+ line_mask : jnp.ndarray
344+ Boolean mask of shape ``(npix,)``; ``True`` where a pixel was
345+ excluded because it lies near an emission line.
346+ continuum_model : jnp.ndarray
347+ Full-spectrum continuum model array of shape ``(npix,)``.
348+ Pixels not covered by any continuum region are ``NaN``.
349+ regions : list of RegionDiagnostic
350+ Per-region fit diagnostics (one entry per region that overlaps
351+ this spectrum).
352+ flux_unit : astropy.units.UnitBase
353+ Flux density unit of *flux* and *error*.
354+ wavelength_unit : astropy.units.UnitBase
355+ Wavelength unit of *wavelength*.
356+ """
357+
358+ name : str
359+ wavelength : jnp .ndarray
360+ flux : jnp .ndarray
361+ error : jnp .ndarray
362+ line_mask : jnp .ndarray
363+ continuum_model : jnp .ndarray
364+ regions : list
365+ flux_unit : object
366+ wavelength_unit : object
367+
368+
287369# ---------------------------------------------------------------------------
288370# ---------------------------------------------------------------------------
289371# Spectra collection
@@ -339,6 +421,7 @@ def __init__(
339421 self ._is_prepared : bool = False
340422 self ._prepared_line_config : LineConfiguration | None = None
341423 self ._prepared_cont_config : ContinuumConfiguration | None = None
424+ self ._scale_diagnostics : list [SpectrumScaleDiagnostic ] | None = None
342425
343426 # Canonical wavelength unit: default to the first spectrum's unit.
344427 if canonical_unit is not None :
@@ -420,6 +503,17 @@ def continuum_scale(self, value: u.Quantity) -> None:
420503 raise ValueError (msg )
421504 self ._continuum_scale = value
422505
506+ @property
507+ def scale_diagnostics (self ) -> list [SpectrumScaleDiagnostic ] | None :
508+ """Per-spectrum diagnostics from the most recent :meth:`compute_scales` call.
509+
510+ Returns a list of :class:`SpectrumScaleDiagnostic` objects (one per
511+ spectrum), each holding the line mask, the fitted continuum model
512+ array, and per-region fit details. ``None`` if :meth:`compute_scales`
513+ has not been called yet.
514+ """
515+ return self ._scale_diagnostics
516+
423517 def compute_scales (
424518 self ,
425519 line_config : LineConfiguration ,
@@ -500,8 +594,9 @@ def _build_line_mask(spectrum, fwhm_kms):
500594 def _fit_continuum_region (wl , flux , error , obs_low , obs_high , line_mask , form ):
501595 """Fit the region's continuum form to unmasked pixels.
502596
503- Returns (model_values_on_region, good_mask, chi2_red).
504- model_values_on_region covers all pixels in_region.
597+ Returns (model_on_region, in_region, good, chi2_red, fit_params).
598+ model_on_region covers all pixels in_region; good is the
599+ unmasked pixel mask used for fitting.
505600 """
506601 from unite .continuum .fit import fit_continuum_form
507602
@@ -511,7 +606,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
511606
512607 min_params = form .n_params # includes normalization_wavelength
513608 if n_good < max (min_params + 1 , 3 ):
514- return None , in_region , None
609+ return None , in_region , good , None , {}
515610
516611 center = float ((obs_low + obs_high ) / 2.0 )
517612 adapted_form = form ._adapt_for_observed_region (obs_low , obs_high )
@@ -522,7 +617,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
522617 # Evaluate on all in-region pixels.
523618 model_region = adapted_form .evaluate (wl [in_region ], center , result .params )
524619
525- return model_region , in_region , result .chi2_red
620+ return model_region , in_region , good , result .chi2_red , result . params
526621
527622 # --- Line scale (with continuum subtraction when available) ---
528623 max_line_scale = 0.0
@@ -550,7 +645,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
550645 line_mask ,
551646 region .form ,
552647 )
553- model_region , in_region , _ = result
648+ model_region , in_region , _good , _ , _ = result
554649 if model_region is not None :
555650 continuum_est = continuum_est .at [in_region ].set (model_region )
556651
@@ -612,7 +707,7 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
612707 line_mask ,
613708 region .form ,
614709 )
615- _ , fit_region , chi2_red = result
710+ _ , fit_region , _good , chi2_red , _ = result
616711 if chi2_red is not None :
617712 all_chi2_reds .append (chi2_red )
618713 region_scale = float (jnp .sqrt (jnp .maximum (chi2_red , 1.0 )))
@@ -634,6 +729,66 @@ def _fit_continuum_region(wl, flux, error, obs_low, obs_high, line_mask, form):
634729 cont_scale_val = max_cont_scale if max_cont_scale > 0 else 1.0
635730 self ._continuum_scale = cont_scale_val * ref_flux_unit
636731
732+ # --- Collect diagnostics ---
733+ diag_list = []
734+ for spectrum in self ._spectra :
735+ wl = spectrum .wavelength
736+ line_mask = _build_line_mask (spectrum , mask_fwhm_kms )
737+ continuum_model_full = jnp .full (spectrum .npix , jnp .nan )
738+ region_diags : list [RegionDiagnostic ] = []
739+
740+ if continuum_config is not None :
741+ for region in continuum_config :
742+ conv = _wavelength_conversion_factor (region ._unit , spectrum .unit )
743+ obs_low = region .low * conv * (1.0 + z )
744+ obs_high = region .high * conv * (1.0 + z )
745+
746+ in_region = (wl >= obs_low ) & (wl <= obs_high )
747+ if not jnp .any (in_region ):
748+ continue
749+
750+ model_region , in_region , good , chi2_red , fit_params = (
751+ _fit_continuum_region (
752+ wl ,
753+ spectrum .flux ,
754+ spectrum .error ,
755+ obs_low ,
756+ obs_high ,
757+ line_mask ,
758+ region .form ,
759+ )
760+ )
761+ if model_region is not None :
762+ continuum_model_full = continuum_model_full .at [in_region ].set (
763+ model_region
764+ )
765+ region_diags .append (
766+ RegionDiagnostic (
767+ obs_low = float (obs_low ),
768+ obs_high = float (obs_high ),
769+ in_region = in_region ,
770+ good_mask = good ,
771+ model_on_region = model_region ,
772+ chi2_red = chi2_red ,
773+ fit_params = fit_params ,
774+ )
775+ )
776+
777+ diag_list .append (
778+ SpectrumScaleDiagnostic (
779+ name = spectrum .name ,
780+ wavelength = wl ,
781+ flux = spectrum .flux ,
782+ error = spectrum .error ,
783+ line_mask = line_mask ,
784+ continuum_model = continuum_model_full ,
785+ regions = region_diags ,
786+ flux_unit = spectrum .flux_unit ,
787+ wavelength_unit = spectrum .unit ,
788+ )
789+ )
790+ self ._scale_diagnostics = diag_list
791+
637792 # -- preparation ----------------------------------------------------------
638793
639794 def prepare (
0 commit comments