|
| 1 | +# SPDX-License-Identifier: BSD-3-Clause |
| 2 | +#!/usr/bin/env python3
| 3 | +# -*- coding: utf-8 -*- |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +__coding__ = "utf-8" |
| 8 | +__authors__ = ["Brian R. Pauw"] |
| 9 | +__copyright__ = "Copyright 2025, The MoDaCor team" |
| 10 | +__date__ = "12/12/2025" |
| 11 | +__status__ = "Development" |
| 12 | + |
| 13 | +__all__ = ["FindScaleFactor1D"] |
| 14 | +__version__ = "20251212.2" |
| 15 | + |
| 16 | +from pathlib import Path |
| 17 | +from typing import Dict |
| 18 | + |
| 19 | +import numpy as np |
| 20 | +from attrs import define |
| 21 | +from scipy.interpolate import interp1d |
| 22 | +from scipy.optimize import least_squares |
| 23 | + |
| 24 | +from modacor import ureg |
| 25 | +from modacor.dataclasses.basedata import BaseData |
| 26 | +from modacor.dataclasses.databundle import DataBundle |
| 27 | +from modacor.dataclasses.process_step import ProcessStep |
| 28 | +from modacor.dataclasses.process_step_describer import ProcessStepDescriber |
| 29 | + |
| 30 | +# ------------------------------------------------------------------------- |
| 31 | +# Small data containers (attrs, not namedtuple) |
| 32 | +# ------------------------------------------------------------------------- |
| 33 | + |
| 34 | + |
@define(slots=True)
class DependentData1D:
    """Container for a 1D dependent signal with its combined sigma and weights."""

    # Dependent signal values as a 1D float array.
    y: np.ndarray
    # Combined (quadrature-summed) 1-sigma uncertainties; same shape as ``y``
    # after broadcasting in ``_extract_dependent``. Non-positive entries are
    # replaced by NaN there so they can be masked out of the fit.
    sigma: np.ndarray
    # Per-point fit weights; same shape as ``y`` after broadcasting.
    weights: np.ndarray
| 40 | + |
| 41 | + |
@define(slots=True)
class FitData1D:
    """Aligned, masked arrays ready for the scale-factor fit.

    All arrays share one shape: the reference grid restricted to the fit
    window, with invalid points removed.
    """

    # Reference x-values kept inside the fit window.
    x: np.ndarray
    # Reference signal on ``x``.
    y_ref: np.ndarray
    # Working signal interpolated onto ``x``.
    y_work: np.ndarray
    # Reference 1-sigma uncertainties on ``x``.
    sigma_ref: np.ndarray
    # Working 1-sigma uncertainties interpolated onto ``x``.
    sigma_work: np.ndarray
    # Combined fit weights (reference × interpolated working, or all ones).
    weights: np.ndarray
| 50 | + |
| 51 | + |
| 52 | +# ------------------------------------------------------------------------- |
| 53 | +# Helpers |
| 54 | +# ------------------------------------------------------------------------- |
| 55 | + |
| 56 | + |
| 57 | +def _combined_sigma(bd: BaseData) -> np.ndarray: |
| 58 | + if not bd.uncertainties: |
| 59 | + return np.asarray(1.0) |
| 60 | + |
| 61 | + sig2 = None |
| 62 | + for u in bd.uncertainties.values(): |
| 63 | + arr = np.asarray(u, dtype=float) |
| 64 | + sig2 = arr * arr if sig2 is None else sig2 + arr * arr |
| 65 | + return np.sqrt(sig2) |
| 66 | + |
| 67 | + |
def _extract_dependent(bd: BaseData) -> DependentData1D:
    """Pull signal, combined sigma and weights out of a rank-1 BaseData.

    Scalar sigma/weights are broadcast to the signal length; non-positive
    sigmas are converted to NaN so they are dropped later during masking.

    Raises:
        ValueError: if the data is not rank-1, the squeezed signal is not 1D,
            or sigma/weights cannot be matched to the signal shape.
    """
    if bd.rank_of_data != 1:
        raise ValueError("Dependent BaseData must be rank-1.")

    y = np.asarray(bd.signal, dtype=float).squeeze()
    if y.ndim != 1:
        raise ValueError("Dependent signal must be 1D.")

    def _match_to_signal(values: np.ndarray) -> np.ndarray:
        # A scalar expands to the full signal length; arrays only drop
        # singleton dimensions.
        if values.size == 1:
            return np.full_like(y, float(values))
        return values.squeeze()

    sigma = _match_to_signal(np.asarray(_combined_sigma(bd), dtype=float))
    weights = _match_to_signal(np.asarray(bd.weights, dtype=float))

    if sigma.shape != y.shape or weights.shape != y.shape:
        raise ValueError("Uncertainties and weights must match dependent signal shape.")

    # Non-positive sigmas carry no usable information; flag them as NaN.
    sigma = np.where(sigma <= 0.0, np.nan, sigma)

    return DependentData1D(y=y, sigma=sigma, weights=weights)
| 95 | + |
| 96 | + |
| 97 | +def _overlap_range(x1: np.ndarray, x2: np.ndarray) -> tuple[float, float]: |
| 98 | + return float(max(np.nanmin(x1), np.nanmin(x2))), float(min(np.nanmax(x1), np.nanmax(x2))) |
| 99 | + |
| 100 | + |
def _prepare_fit_data(
    *,
    x_work: np.ndarray,
    dep_work: DependentData1D,
    x_ref: np.ndarray,
    dep_ref: DependentData1D,
    require_overlap: bool,
    interpolation_kind: str,
    fit_min: float,
    fit_max: float,
    use_weights: bool,
) -> FitData1D:
    """Build masked, aligned arrays for the scale-factor fit.

    The reference grid inside ``[fit_min, fit_max]`` (clipped to the overlap
    of both x-axes when ``require_overlap``) is kept as-is; the working curve,
    its sigma and its weights are interpolated onto that grid. Points that are
    non-finite, have non-positive reference sigma, negative working sigma or
    non-positive combined weight are dropped.

    Raises:
        ValueError: if there is no axis overlap (when required), the fit
            window is empty, or fewer than two usable points remain.
    """
    ov_min, ov_max = _overlap_range(x_ref, x_work)
    if require_overlap and not (ov_min < ov_max):
        raise ValueError("No overlap between working and reference x-axes.")

    # Clip the requested window to the overlap only when overlap is enforced.
    lo = max(fit_min, ov_min) if require_overlap else fit_min
    hi = min(fit_max, ov_max) if require_overlap else fit_max
    if not lo < hi:
        raise ValueError("Empty fit range after overlap constraints.")

    mask = (x_ref >= lo) & (x_ref <= hi)
    if np.count_nonzero(mask) < 2:
        raise ValueError("Not enough points in fit window.")

    x_fit = x_ref[mask]
    y_ref = dep_ref.y[mask]
    sigma_ref = dep_ref.sigma[mask]
    weights_ref = dep_ref.weights[mask]

    # interp1d with assume_sorted=True requires ascending x; sort working data.
    order = np.argsort(x_work)
    x_work = x_work[order]
    y_work = dep_work.y[order]
    sigma_work = dep_work.sigma[order]
    weights_work = dep_work.weights[order]

    bounds_error = require_overlap
    # FIX: the previous code passed fill_value=None when bounds_error is True.
    # None is not a documented fill_value for scipy's interp1d (docs allow
    # array-like or "extrapolate"); use np.nan, interp1d's documented default,
    # which is unused anyway when bounds_error raises on out-of-range input.
    fill_value = np.nan if bounds_error else "extrapolate"

    interp_y = interp1d(
        x_work, y_work, kind=interpolation_kind, bounds_error=bounds_error, fill_value=fill_value, assume_sorted=True
    )
    # Uncertainties and weights are always interpolated linearly, regardless
    # of the signal's interpolation kind.
    interp_sigma = interp1d(
        x_work, sigma_work, kind="linear", bounds_error=bounds_error, fill_value=fill_value, assume_sorted=True
    )
    interp_w = interp1d(
        x_work, weights_work, kind="linear", bounds_error=bounds_error, fill_value=fill_value, assume_sorted=True
    )

    y_work_i = interp_y(x_fit)
    sigma_work_i = interp_sigma(x_fit)
    weights_work_i = interp_w(x_fit)

    weights = (weights_ref * weights_work_i) if use_weights else np.ones_like(y_ref)

    # Keep only points where everything the residual needs is finite and usable.
    valid = (
        np.isfinite(y_ref)
        & np.isfinite(y_work_i)
        & np.isfinite(sigma_ref)
        & (sigma_ref > 0)
        & np.isfinite(sigma_work_i)
        & (sigma_work_i >= 0)
        & np.isfinite(weights)
        & (weights > 0)
    )

    if np.count_nonzero(valid) < 2:
        raise ValueError("Not enough valid points after masking.")

    return FitData1D(
        x=x_fit[valid],
        y_ref=y_ref[valid],
        y_work=y_work_i[valid],
        sigma_ref=sigma_ref[valid],
        sigma_work=sigma_work_i[valid],
        weights=weights[valid],
    )
| 179 | + |
| 180 | + |
| 181 | +# ------------------------------------------------------------------------- |
| 182 | +# Main ProcessStep |
| 183 | +# ------------------------------------------------------------------------- |
| 184 | + |
| 185 | + |
class FindScaleFactor1D(ProcessStep):
    """Determine the multiplicative scale factor (and optionally an additive
    background) that best maps a working 1D curve onto a reference curve.

    Compute-only: neither input curve is rescaled. The fitted scale factor
    (and background, when requested) are written into the working DataBundle
    as new rank-0 BaseData entries under ``scale_output_key`` /
    ``background_output_key``.
    """

    documentation = ProcessStepDescriber(
        calling_name="Scale 1D curve to reference (compute-only)",
        calling_id="FindScaleFactor1D",
        calling_module_path=Path(__file__),
        calling_version=__version__,
        required_data_keys=["signal"],
        modifies={
            "scale_factor": ["signal", "uncertainties", "units"],
            "scale_background": ["signal", "uncertainties", "units"],
        },
        calling_arguments={
            "signal_key": "signal",
            "independent_axis_key": "Q",
            "scale_output_key": "scale_factor",
            "background_output_key": "scale_background",
            "fit_background": False,
            "fit_min_val": None,
            "fit_max_val": None,
            "fit_val_units": None,
            "require_overlap": True,
            "interpolation_kind": "linear",
            "robust_loss": "huber",
            "robust_fscale": 1.0,
            "use_basedata_weights": True,
        },
        step_keywords=["scale", "calibration", "1D"],
        step_doc="Compute scale factor between two 1D curves using robust least squares.",
    )

    def calculate(self) -> Dict[str, DataBundle]:
        """Fit scale (and optional background) and store them in the working bundle.

        Expects ``configuration["with_processing_keys"]`` to hold exactly two
        keys, in order (working, reference). Returns a dict mapping the
        working key to its bundle with the fit results added.
        """
        cfg = self.configuration
        # First key is the working (to-be-scaled) bundle, second the reference.
        work_key, ref_key = cfg["with_processing_keys"]

        sig_key = cfg.get("signal_key", "signal")
        axis_key = cfg.get("independent_axis_key", "Q")

        work_db = self.processing_data[work_key]
        ref_db = self.processing_data[ref_key]

        # Copies protect the stored bundles from the unit conversion below.
        y_work_bd = work_db[sig_key].copy(with_axes=True)
        y_ref_bd = ref_db[sig_key].copy(with_axes=True)

        x_work_bd = work_db[axis_key].copy(with_axes=False)
        x_ref_bd = ref_db[axis_key].copy(with_axes=False)

        # Bring both x-axes into the reference units before comparing.
        # NOTE(review): assumes to_units converts the copy in place — confirm
        # against the BaseData API.
        if x_work_bd.units != x_ref_bd.units:
            x_work_bd.to_units(x_ref_bd.units)

        x_work = np.asarray(x_work_bd.signal, dtype=float).squeeze()
        x_ref = np.asarray(x_ref_bd.signal, dtype=float).squeeze()

        dep_work = _extract_dependent(y_work_bd)
        dep_ref = _extract_dependent(y_ref_bd)

        fit_min = cfg.get("fit_min_val")
        fit_max = cfg.get("fit_max_val")

        # Fit bounds are given in fit_val_units (defaulting to the reference
        # axis units) and converted to reference-axis magnitudes; unset bounds
        # fall back to the full reference range.
        fit_units = cfg.get("fit_val_units") or x_ref_bd.units
        if fit_min is not None:
            fit_min = ureg.Quantity(fit_min, fit_units).to(x_ref_bd.units).magnitude
        else:
            fit_min = np.nanmin(x_ref)

        if fit_max is not None:
            fit_max = ureg.Quantity(fit_max, fit_units).to(x_ref_bd.units).magnitude
        else:
            fit_max = np.nanmax(x_ref)

        fit_data = _prepare_fit_data(
            x_work=x_work,
            dep_work=dep_work,
            x_ref=x_ref,
            dep_ref=dep_ref,
            require_overlap=cfg.get("require_overlap", True),
            interpolation_kind=cfg.get("interpolation_kind", "linear"),
            fit_min=float(fit_min),
            fit_max=float(fit_max),
            use_weights=cfg.get("use_basedata_weights", True),
        )

        fit_background = bool(cfg.get("fit_background", False))

        def residuals(p: np.ndarray) -> np.ndarray:
            # Weighted, sigma-normalized residuals of model = scale*y_work + bg.
            scale = p[0]
            background = p[1] if fit_background else 0.0
            model = scale * fit_data.y_work + background
            # Propagate both curves' uncertainties: the working sigma enters
            # scaled by the current scale estimate.
            sigma = np.sqrt(fit_data.sigma_ref**2 + (scale * fit_data.sigma_work) ** 2)
            r = (fit_data.y_ref - model) / sigma
            # sqrt(w) so the squared residual is weighted by w.
            return np.sqrt(fit_data.weights) * r

        # Initial guess from an unweighted linear least-squares solve.
        if fit_background:
            X = np.column_stack([fit_data.y_work, np.ones_like(fit_data.y_work)])
            x0, *_ = np.linalg.lstsq(X, fit_data.y_ref, rcond=None)
        else:
            # `or 1.0` guards against a zero denominator when y_work is all zeros.
            denom = np.dot(fit_data.y_work, fit_data.y_work) or 1.0
            x0 = np.array([np.dot(fit_data.y_ref, fit_data.y_work) / denom])

        res = least_squares(
            residuals,
            x0=x0,
            loss=cfg.get("robust_loss", "huber"),
            f_scale=float(cfg.get("robust_fscale", 1.0)),
        )

        # Parameter covariance from the Jacobian at the solution, scaled by
        # the reduced chi-square (dof floored at 1 to avoid division by zero).
        J = res.jac
        dof = max(1, len(res.fun) - len(res.x))
        s_sq = np.sum(res.fun**2) / dof

        # pinv tolerates a rank-deficient J^T J.
        cov = s_sq * np.linalg.pinv(J.T @ J)
        sig_params = np.sqrt(np.clip(np.diag(cov), 0.0, np.inf))

        scale = float(res.x[0])
        scale_sigma = float(sig_params[0])

        # Store the scale factor as dimensionless rank-0 data on the working bundle.
        out_key = cfg.get("scale_output_key", "scale_factor")
        work_db[out_key] = BaseData(
            signal=np.array([scale]),
            units="dimensionless",
            uncertainties={"propagate_to_all": np.array([scale_sigma])},
            rank_of_data=0,
        )

        # The background shares the reference signal's units.
        if fit_background:
            bg_key = cfg.get("background_output_key", "scale_background")
            work_db[bg_key] = BaseData(
                signal=np.array([float(res.x[1])]),
                units=y_ref_bd.units,
                uncertainties={"propagate_to_all": np.array([sig_params[1]])},
                rank_of_data=0,
            )

        return {work_key: work_db}
0 commit comments