Skip to content

Commit 034555d

Browse files
author
Rama Vasudevan
committed
new tests
1 parent d8f8f42 commit 034555d

File tree

1 file changed

+144
-1
lines changed

1 file changed

+144
-1
lines changed

tests/proc/test_fitter_refactor.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,4 +227,147 @@ def test_2d_fit_execution(self):
227227
diag = np.diag(cov_matrix)
228228
self.assertTrue(np.all(diag >= 0), "Covariance diagonal elements must be non-negative")
229229

230-
230+
class TestSidpyFitterWithBounds(unittest.TestCase):
231+
232+
def setUp(self):
233+
"""
234+
Synthetic 3x3 spatial x 50-point spectral 1D Gaussian dataset.
235+
Fast, self-contained, no network required.
236+
"""
237+
self.n_x, self.n_y, self.n_spec = 3, 3, 50
238+
x_axis = np.linspace(-10, 10, self.n_spec)
239+
self.true_params = np.array([3.0, 0.0, 2.0, 0.5]) # amp, center, sigma, offset
240+
241+
raw = np.zeros((self.n_x, self.n_y, self.n_spec))
242+
for i in range(self.n_x):
243+
for j in range(self.n_y):
244+
amp, cen, sig, off = self.true_params
245+
cen_ij = cen + 0.5 * (i - 1)
246+
raw[i, j] = (amp * np.exp(-0.5 * ((x_axis - cen_ij) / sig) ** 2)
247+
+ off
248+
+ np.random.default_rng(i * 10 + j).normal(0, 0.05, self.n_spec))
249+
250+
self.dataset = sid.Dataset.from_array(raw, name='Synthetic_1D_Gauss')
251+
self.dataset.set_dimension(0, sid.Dimension(np.arange(self.n_x), 'x',
252+
dimension_type='spatial'))
253+
self.dataset.set_dimension(1, sid.Dimension(np.arange(self.n_y), 'y',
254+
dimension_type='spatial'))
255+
self.dataset.set_dimension(2, sid.Dimension(x_axis, 'spectrum',
256+
dimension_type='spectral'))
257+
258+
def _gaussian(x, amp, cen, sig, off):
259+
return amp * np.exp(-0.5 * ((x - cen) / sig) ** 2) + off
260+
261+
def _gaussian_guess(x, y):
262+
off = np.percentile(y, 10)
263+
amp = float(y.max()) - off
264+
cen = float(x[np.argmax(y)])
265+
sig = (x[-1] - x[0]) / 6.0
266+
return [amp, cen, sig, off]
267+
268+
self.model_func = _gaussian
269+
self.guess_func = _gaussian_guess
270+
271+
def _make_fitter(self, lower_bounds=None, upper_bounds=None):
272+
"""Helper: build and setup a fitter with optional bounds."""
273+
fitter = SidpyFitterRefactor(
274+
self.dataset, self.model_func, self.guess_func,
275+
lower_bounds=lower_bounds, upper_bounds=upper_bounds,
276+
)
277+
fitter.setup_calc()
278+
return fitter
279+
280+
def test_unbounded_unchanged(self):
281+
"""Unbounded fit must return a valid sidpy.Dataset with finite params."""
282+
result, _ = self._make_fitter().do_fit()
283+
self.assertIsInstance(result, sid.Dataset)
284+
self.assertTrue(np.all(np.isfinite(np.array(result))))
285+
286+
def test_scalar_lower_bound(self):
287+
"""Scalar lower_bounds=0 — all returned params must be >= 0."""
288+
result, _ = self._make_fitter(lower_bounds=0.0).do_fit()
289+
params = np.array(result)
290+
self.assertTrue(np.all(params >= -1e-6),
291+
msg=f"Some params violated lower_bound=0: min={params.min()}")
292+
293+
def test_scalar_upper_bound(self):
294+
"""Scalar upper_bounds=1e6 — all returned params must be <= 1e6."""
295+
upper = 1e6
296+
result, _ = self._make_fitter(upper_bounds=upper).do_fit()
297+
params = np.array(result)
298+
self.assertTrue(np.all(params <= upper + 1e-6),
299+
msg=f"Some params violated upper_bound={upper}: max={params.max()}")
300+
301+
def test_per_param_bounds_respected(self):
302+
"""Per-parameter array bounds — each param stays within its own [lb, ub]."""
303+
n = self._make_fitter().num_params
304+
lb = np.zeros(n)
305+
ub = np.full(n, 1e6)
306+
result, _ = self._make_fitter(lower_bounds=lb, upper_bounds=ub).do_fit()
307+
params = np.array(result)
308+
for i in range(n):
309+
p = params[..., i]
310+
self.assertTrue(np.all(p >= lb[i] - 1e-6),
311+
msg=f"Param {i} violated lower bound {lb[i]}: min={p.min()}")
312+
self.assertTrue(np.all(p <= ub[i] + 1e-6),
313+
msg=f"Param {i} violated upper bound {ub[i]}: max={p.max()}")
314+
315+
def test_guess_outside_bounds_no_crash(self):
316+
"""Guess outside bounds must be clipped silently, not crash."""
317+
n = self._make_fitter().num_params
318+
fitter = self._make_fitter(lower_bounds=np.full(n, -1e-10),
319+
upper_bounds=np.full(n, 1e-10))
320+
try:
321+
fitter.do_fit()
322+
except Exception as e:
323+
self.fail(f"do_fit raised unexpectedly with out-of-bounds guess: {e}")
324+
325+
def test_bounds_length_mismatch_raises(self):
326+
"""Bound array with wrong length must raise ValueError."""
327+
n = self._make_fitter().num_params
328+
fitter = self._make_fitter(lower_bounds=np.zeros(n + 3))
329+
with self.assertRaises(ValueError):
330+
fitter.do_fit()
331+
332+
def test_lb_greater_than_ub_raises(self):
333+
"""lower_bounds > upper_bounds must raise ValueError."""
334+
n = self._make_fitter().num_params
335+
fitter = self._make_fitter(lower_bounds=np.full(n, 10.0),
336+
upper_bounds=np.full(n, 1.0))
337+
with self.assertRaises(ValueError):
338+
fitter.do_fit()
339+
340+
def test_bounds_stored_in_metadata(self):
341+
"""Bounds must appear correctly in result metadata."""
342+
n = self._make_fitter().num_params
343+
lb = list(np.zeros(n))
344+
ub = list(np.ones(n) * 1e6)
345+
result, _ = self._make_fitter(lower_bounds=lb, upper_bounds=ub).do_fit()
346+
meta = result.metadata["fit_parameters"]
347+
self.assertIn("lower_bounds", meta)
348+
self.assertIn("upper_bounds", meta)
349+
self.assertEqual(meta["lower_bounds"], lb)
350+
self.assertEqual(meta["upper_bounds"], ub)
351+
352+
def test_none_bounds_metadata_is_none(self):
353+
"""When no bounds are passed, metadata entries must be None."""
354+
result, _ = self._make_fitter().do_fit()
355+
meta = result.metadata["fit_parameters"]
356+
self.assertIsNone(meta.get("lower_bounds"))
357+
self.assertIsNone(meta.get("upper_bounds"))
358+
359+
def test_bounds_with_nonlinear_loss(self):
360+
"""Bounds + non-linear loss must not raise (both require method='trf')."""
361+
fitter = self._make_fitter(lower_bounds=0.0)
362+
try:
363+
result, _ = fitter.do_fit(loss='soft_l1')
364+
except Exception as e:
365+
self.fail(f"do_fit raised with bounds + non-linear loss: {e}")
366+
self.assertIsInstance(result, sid.Dataset)
367+
368+
def test_bounds_with_return_cov(self):
369+
"""Bounds must be compatible with return_cov=True; covariance shape check."""
370+
n = self._make_fitter().num_params
371+
params, cov = self._make_fitter(lower_bounds=0.0).do_fit(return_cov=True)
372+
self.assertEqual(cov.shape[-2:], (n, n),
373+
msg=f"Covariance shape {cov.shape} does not end in ({n},{n})")

0 commit comments

Comments
 (0)