Commit 30dc0b8

Merge pull request #68 from BAMresearch/find_scaling

Find scaling factor module

2 parents 4993339 + d196636

File tree

5 files changed: +742 -0 lines changed

src/modacor/modules/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -13,7 +13,9 @@
 
 # official steps are imported here for ease
 from modacor.modules.base_modules.divide import Divide
+from modacor.modules.base_modules.find_scale_factor1d import FindScaleFactor1D
 from modacor.modules.base_modules.multiply import Multiply
+from modacor.modules.base_modules.multiply_databundles import MultiplyDatabundles
 from modacor.modules.base_modules.poisson_uncertainties import PoissonUncertainties
 from modacor.modules.base_modules.reduce_dimensionality import ReduceDimensionality
 from modacor.modules.base_modules.subtract import Subtract

@@ -27,7 +29,9 @@
     "Divide",
     "IndexPixels",
     "IndexedAverager",
+    "FindScaleFactor1D",
     "Multiply",
+    "MultiplyDatabundles",
     "PoissonUncertainties",
     "ReduceDimensionality",
     "SolidAngleCorrection",
src/modacor/modules/base_modules/find_scale_factor1d.py (new file)

Lines changed: 318 additions & 0 deletions

@@ -0,0 +1,318 @@
# SPDX-License-Identifier: BSD-3-Clause
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

__coding__ = "utf-8"
__authors__ = ["Brian R. Pauw"]
__copyright__ = "Copyright 2025, The MoDaCor team"
__date__ = "12/12/2025"
__status__ = "Development"

__all__ = ["FindScaleFactor1D"]
__version__ = "20251212.2"

from pathlib import Path
from typing import Dict

import numpy as np
from attrs import define
from scipy.interpolate import interp1d
from scipy.optimize import least_squares

from modacor import ureg
from modacor.dataclasses.basedata import BaseData
from modacor.dataclasses.databundle import DataBundle
from modacor.dataclasses.process_step import ProcessStep
from modacor.dataclasses.process_step_describer import ProcessStepDescriber

# -------------------------------------------------------------------------
# Small data containers (attrs, not namedtuple)
# -------------------------------------------------------------------------


@define(slots=True)
class DependentData1D:
    y: np.ndarray
    sigma: np.ndarray
    weights: np.ndarray


@define(slots=True)
class FitData1D:
    x: np.ndarray
    y_ref: np.ndarray
    y_work: np.ndarray
    sigma_ref: np.ndarray
    sigma_work: np.ndarray
    weights: np.ndarray


# -------------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------------


def _combined_sigma(bd: BaseData) -> np.ndarray:
    """Combine all uncertainty contributions of a BaseData in quadrature."""
    if not bd.uncertainties:
        return np.asarray(1.0)

    sig2 = None
    for u in bd.uncertainties.values():
        arr = np.asarray(u, dtype=float)
        sig2 = arr * arr if sig2 is None else sig2 + arr * arr
    return np.sqrt(sig2)
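
# Illustration (not part of the committed module): two uncertainty sources
# are combined in quadrature, e.g.
#     stat = np.array([0.3, 0.4]); syst = np.array([0.4, 0.3])
#     np.sqrt(stat**2 + syst**2)  # -> array([0.5, 0.5])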

def _extract_dependent(bd: BaseData) -> DependentData1D:
    """Pull the signal, combined sigma, and weights out of a rank-1 BaseData."""
    if bd.rank_of_data != 1:
        raise ValueError("Dependent BaseData must be rank-1.")

    y = np.asarray(bd.signal, dtype=float).squeeze()
    if y.ndim != 1:
        raise ValueError("Dependent signal must be 1D.")

    sigma = np.asarray(_combined_sigma(bd), dtype=float)
    weights = np.asarray(bd.weights, dtype=float)

    if sigma.size == 1:
        sigma = np.full_like(y, float(sigma))
    else:
        sigma = sigma.squeeze()

    if weights.size == 1:
        weights = np.full_like(y, float(weights))
    else:
        weights = weights.squeeze()

    if sigma.shape != y.shape or weights.shape != y.shape:
        raise ValueError("Uncertainties and weights must match dependent signal shape.")

    # non-positive sigmas cannot be used for weighting; flag them for masking
    sigma = np.where(sigma <= 0.0, np.nan, sigma)

    return DependentData1D(y=y, sigma=sigma, weights=weights)


def _overlap_range(x1: np.ndarray, x2: np.ndarray) -> tuple[float, float]:
    return float(max(np.nanmin(x1), np.nanmin(x2))), float(min(np.nanmax(x1), np.nanmax(x2)))


def _prepare_fit_data(
    *,
    x_work: np.ndarray,
    dep_work: DependentData1D,
    x_ref: np.ndarray,
    dep_ref: DependentData1D,
    require_overlap: bool,
    interpolation_kind: str,
    fit_min: float,
    fit_max: float,
    use_weights: bool,
) -> FitData1D:
    ov_min, ov_max = _overlap_range(x_ref, x_work)
    if require_overlap and not (ov_min < ov_max):
        raise ValueError("No overlap between working and reference x-axes.")

    lo = max(fit_min, ov_min) if require_overlap else fit_min
    hi = min(fit_max, ov_max) if require_overlap else fit_max
    if not lo < hi:
        raise ValueError("Empty fit range after overlap constraints.")

    mask = (x_ref >= lo) & (x_ref <= hi)
    if np.count_nonzero(mask) < 2:
        raise ValueError("Not enough points in fit window.")

    x_fit = x_ref[mask]
    y_ref = dep_ref.y[mask]
    sigma_ref = dep_ref.sigma[mask]
    weights_ref = dep_ref.weights[mask]

    # sort working data so interp1d can assume a monotonic abscissa
    order = np.argsort(x_work)
    x_work = x_work[order]
    y_work = dep_work.y[order]
    sigma_work = dep_work.sigma[order]
    weights_work = dep_work.weights[order]

    bounds_error = require_overlap
    fill_value = None if bounds_error else "extrapolate"

    interp_y = interp1d(
        x_work, y_work, kind=interpolation_kind, bounds_error=bounds_error, fill_value=fill_value, assume_sorted=True
    )
    interp_sigma = interp1d(
        x_work, sigma_work, kind="linear", bounds_error=bounds_error, fill_value=fill_value, assume_sorted=True
    )
    interp_w = interp1d(
        x_work, weights_work, kind="linear", bounds_error=bounds_error, fill_value=fill_value, assume_sorted=True
    )

    y_work_i = interp_y(x_fit)
    sigma_work_i = interp_sigma(x_fit)
    weights_work_i = interp_w(x_fit)

    weights = (weights_ref * weights_work_i) if use_weights else np.ones_like(y_ref)

    valid = (
        np.isfinite(y_ref)
        & np.isfinite(y_work_i)
        & np.isfinite(sigma_ref)
        & (sigma_ref > 0)
        & np.isfinite(sigma_work_i)
        & (sigma_work_i >= 0)
        & np.isfinite(weights)
        & (weights > 0)
    )

    if np.count_nonzero(valid) < 2:
        raise ValueError("Not enough valid points after masking.")

    return FitData1D(
        x=x_fit[valid],
        y_ref=y_ref[valid],
        y_work=y_work_i[valid],
        sigma_ref=sigma_ref[valid],
        sigma_work=sigma_work_i[valid],
        weights=weights[valid],
    )
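
# Illustrative smoke test (not part of the committed module; synthetic values,
# using only numpy and the definitions above):
#
#     x_ref = np.linspace(0.0, 1.0, 50)
#     x_work = np.linspace(0.1, 1.2, 80)
#     dep_ref = DependentData1D(y=np.exp(-x_ref), sigma=np.full_like(x_ref, 0.01), weights=np.ones_like(x_ref))
#     dep_work = DependentData1D(y=2.0 * np.exp(-x_work), sigma=np.full_like(x_work, 0.01), weights=np.ones_like(x_work))
#     fd = _prepare_fit_data(
#         x_work=x_work, dep_work=dep_work, x_ref=x_ref, dep_ref=dep_ref,
#         require_overlap=True, interpolation_kind="linear",
#         fit_min=0.0, fit_max=1.0, use_weights=True,
#     )
#     # reference points outside the overlap (x < 0.1) are dropped and the
#     # working curve is resampled onto the surviving reference grid fd.x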

# -------------------------------------------------------------------------
# Main ProcessStep
# -------------------------------------------------------------------------


class FindScaleFactor1D(ProcessStep):
    documentation = ProcessStepDescriber(
        calling_name="Scale 1D curve to reference (compute-only)",
        calling_id="FindScaleFactor1D",
        calling_module_path=Path(__file__),
        calling_version=__version__,
        required_data_keys=["signal"],
        modifies={
            "scale_factor": ["signal", "uncertainties", "units"],
            "scale_background": ["signal", "uncertainties", "units"],
        },
        calling_arguments={
            "signal_key": "signal",
            "independent_axis_key": "Q",
            "scale_output_key": "scale_factor",
            "background_output_key": "scale_background",
            "fit_background": False,
            "fit_min_val": None,
            "fit_max_val": None,
            "fit_val_units": None,
            "require_overlap": True,
            "interpolation_kind": "linear",
            "robust_loss": "huber",
            "robust_fscale": 1.0,
            "use_basedata_weights": True,
        },
        step_keywords=["scale", "calibration", "1D"],
        step_doc="Compute scale factor between two 1D curves using robust least squares.",
    )
    def calculate(self) -> Dict[str, DataBundle]:
        cfg = self.configuration
        work_key, ref_key = cfg["with_processing_keys"]

        sig_key = cfg.get("signal_key", "signal")
        axis_key = cfg.get("independent_axis_key", "Q")

        work_db = self.processing_data[work_key]
        ref_db = self.processing_data[ref_key]

        y_work_bd = work_db[sig_key].copy(with_axes=True)
        y_ref_bd = ref_db[sig_key].copy(with_axes=True)

        x_work_bd = work_db[axis_key].copy(with_axes=False)
        x_ref_bd = ref_db[axis_key].copy(with_axes=False)

        if x_work_bd.units != x_ref_bd.units:
            x_work_bd.to_units(x_ref_bd.units)

        x_work = np.asarray(x_work_bd.signal, dtype=float).squeeze()
        x_ref = np.asarray(x_ref_bd.signal, dtype=float).squeeze()

        dep_work = _extract_dependent(y_work_bd)
        dep_ref = _extract_dependent(y_ref_bd)

        fit_min = cfg.get("fit_min_val")
        fit_max = cfg.get("fit_max_val")

        fit_units = cfg.get("fit_val_units") or x_ref_bd.units
        if fit_min is not None:
            fit_min = ureg.Quantity(fit_min, fit_units).to(x_ref_bd.units).magnitude
        else:
            fit_min = np.nanmin(x_ref)

        if fit_max is not None:
            fit_max = ureg.Quantity(fit_max, fit_units).to(x_ref_bd.units).magnitude
        else:
            fit_max = np.nanmax(x_ref)

        fit_data = _prepare_fit_data(
            x_work=x_work,
            dep_work=dep_work,
            x_ref=x_ref,
            dep_ref=dep_ref,
            require_overlap=cfg.get("require_overlap", True),
            interpolation_kind=cfg.get("interpolation_kind", "linear"),
            fit_min=float(fit_min),
            fit_max=float(fit_max),
            use_weights=cfg.get("use_basedata_weights", True),
        )

        fit_background = bool(cfg.get("fit_background", False))

        def residuals(p: np.ndarray) -> np.ndarray:
            scale = p[0]
            background = p[1] if fit_background else 0.0
            model = scale * fit_data.y_work + background
            sigma = np.sqrt(fit_data.sigma_ref**2 + (scale * fit_data.sigma_work) ** 2)
            r = (fit_data.y_ref - model) / sigma
            return np.sqrt(fit_data.weights) * r
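
        # The residual above is sqrt(w_i) * (y_ref_i - (s * y_work_i + b))
        #   / sqrt(sigma_ref_i**2 + (s * sigma_work_i)**2):
        # the working-curve uncertainty is scaled together with the signal and
        # combined in quadrature with the reference uncertainty.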

        if fit_background:
            # linear least-squares initial guess for (scale, background)
            X = np.column_stack([fit_data.y_work, np.ones_like(fit_data.y_work)])
            x0, *_ = np.linalg.lstsq(X, fit_data.y_ref, rcond=None)
        else:
            # closed-form unweighted scale estimate as the starting point
            denom = np.dot(fit_data.y_work, fit_data.y_work) or 1.0
            x0 = np.array([np.dot(fit_data.y_ref, fit_data.y_work) / denom])

        res = least_squares(
            residuals,
            x0=x0,
            loss=cfg.get("robust_loss", "huber"),
            f_scale=float(cfg.get("robust_fscale", 1.0)),
        )

        # parameter covariance estimate: cov = s^2 * (J^T J)^-1, with s^2 the
        # reduced chi-square of the weighted residuals
        J = res.jac
        dof = max(1, len(res.fun) - len(res.x))
        s_sq = np.sum(res.fun**2) / dof

        cov = s_sq * np.linalg.pinv(J.T @ J)
        sig_params = np.sqrt(np.clip(np.diag(cov), 0.0, np.inf))

        scale = float(res.x[0])
        scale_sigma = float(sig_params[0])

        out_key = cfg.get("scale_output_key", "scale_factor")
        work_db[out_key] = BaseData(
            signal=np.array([scale]),
            units="dimensionless",
            uncertainties={"propagate_to_all": np.array([scale_sigma])},
            rank_of_data=0,
        )

        if fit_background:
            bg_key = cfg.get("background_output_key", "scale_background")
            work_db[bg_key] = BaseData(
                signal=np.array([float(res.x[1])]),
                units=y_ref_bd.units,
                uncertainties={"propagate_to_all": np.array([sig_params[1]])},
                rank_of_data=0,
            )

        return {work_key: work_db}
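
How the step is wired into a pipeline is not part of this diff, but the keys read in calculate() imply a configuration along these lines (a sketch; the bundle names are hypothetical):

step_config = {
    "with_processing_keys": ["sample", "reference"],  # (working, reference) DataBundles
    "signal_key": "signal",
    "independent_axis_key": "Q",
    "fit_background": True,       # also fit an additive background term
    "fit_min_val": 0.1,
    "fit_max_val": 1.0,
    "fit_val_units": "1/nm",      # converted to the reference axis units via ureg
    "robust_loss": "huber",       # forwarded to scipy.optimize.least_squares
    "robust_fscale": 1.0,
}
# On success, the working DataBundle gains a rank-0 "scale_factor" BaseData,
# plus "scale_background" when fit_background is True.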
