Skip to content

Commit d8f8f42

Browse files
committed
updated fitter refactor with bounds
1 parent 3718ca6 commit d8f8f42

File tree

1 file changed

+66
-5
lines changed

1 file changed

+66
-5
lines changed

sidpy/proc/fitter_refactor.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class SidpyFitterRefactor:
2929
A dictionary containing fit parameters, model source code, and configuration.
3030
"""
3131

32-
def __init__(self, dataset, model_function, guess_function, ind_dims=None, num_params=None):
32+
def __init__(self, dataset, model_function, guess_function, ind_dims=None, num_params=None,
33+
lower_bounds=None, upper_bounds=None):
3334
"""
3435
Initializes the SidpyFitterKMeans.
3536
@@ -45,6 +46,15 @@ def __init__(self, dataset, model_function, guess_function, ind_dims=None, num_p
4546
The indices of the dimensions to fit over. Default is whatever are the spectral dimensions
4647
num_params: int, optional but required in case of 2D or higher fitting
4748
The number of parameters the fitting function expects.
49+
lower_bounds : None, float, or array-like, optional
50+
Lower bounds for the fit parameters. Can be:
51+
- None (default): no lower bound (-inf) on any parameter.
52+
- scalar float: the same lower bound applied to every parameter.
53+
- array-like of length num_params: per-parameter lower bounds.
54+
Must satisfy lower_bounds <= upper_bounds element-wise.
55+
upper_bounds : None, float, or array-like, optional
56+
Upper bounds for the fit parameters. Same rules as lower_bounds.
57+
Must satisfy upper_bounds >= lower_bounds element-wise.
4858
"""
4959
import sidpy
5060
if not isinstance(dataset, sidpy.Dataset):
@@ -54,7 +64,11 @@ def __init__(self, dataset, model_function, guess_function, ind_dims=None, num_p
5464
self.dask_data = dataset
5565
self.model_func = model_function
5666
self.guess_func = guess_function
57-
67+
68+
# Store bounds early so _fit_logic and do_kmeans_guess always see them.
69+
self.lower_bounds = lower_bounds
70+
self.upper_bounds = upper_bounds
71+
5872
self.ndim = self.dataset.ndim
5973
self.spectral_dims = tuple(self.dataset.get_spectral_dims())
6074

@@ -73,7 +87,13 @@ def __init__(self, dataset, model_function, guess_function, ind_dims=None, num_p
7387
"ind_dims": self.ind_dims,
7488
"is_complex": self.is_complex,
7589
"use_kmeans": False, # Updated during do_fit
76-
"n_clusters": None # Updated during do_fit
90+
"n_clusters": None, # Updated during do_fit
91+
"lower_bounds": (lower_bounds.tolist()
92+
if isinstance(lower_bounds, np.ndarray)
93+
else lower_bounds),
94+
"upper_bounds": (upper_bounds.tolist()
95+
if isinstance(upper_bounds, np.ndarray)
96+
else upper_bounds),
7797
},
7898
"source_code": {
7999
"model_function": self._get_source(model_function),
@@ -131,7 +151,41 @@ def _fit_logic(self, y_vec, x_in, initial_guess, loss='linear', f_scale=1.0, ret
131151
"""
132152
y_vec = np.squeeze(np.asarray(y_vec))
133153
initial_guess = np.asarray(initial_guess).ravel()
134-
154+
n = initial_guess.size
155+
156+
# --- Resolve bounds --------------------------------------------
157+
def _resolve_bound(b, n_params, fill_value):
158+
"""Return a float64 array of length n_params from a flexible input."""
159+
if b is None:
160+
return np.full(n_params, fill_value, dtype=np.float64)
161+
arr = np.asarray(b, dtype=np.float64).ravel()
162+
if arr.size == 1:
163+
return np.full(n_params, arr[0], dtype=np.float64)
164+
if arr.size != n_params:
165+
raise ValueError(
166+
f"Bound array length ({arr.size}) does not match "
167+
f"num_params ({n_params}). Provide a scalar or an "
168+
f"array-like of length {n_params}."
169+
)
170+
return arr
171+
172+
lb = _resolve_bound(self.lower_bounds, n, -np.inf)
173+
ub = _resolve_bound(self.upper_bounds, n, np.inf)
174+
175+
if np.any(lb > ub):
176+
raise ValueError(
177+
"lower_bounds must be <= upper_bounds for every parameter. "
178+
f"Violations at indices: {np.where(lb > ub)[0].tolist()}"
179+
)
180+
181+
# Clip guess into bounds so least_squares doesn't raise immediately.
182+
initial_guess = np.clip(initial_guess, lb, ub)
183+
184+
# 'lm' only supports linear loss and no bounds; use 'trf' otherwise.
185+
has_finite_bounds = np.any(np.isfinite(lb)) or np.any(np.isfinite(ub))
186+
method = 'trf' if (has_finite_bounds or loss != 'linear') else 'lm'
187+
# ---------------------------------------------------------------
188+
135189
# Prepare data for least_squares (handle complex)
136190
if self.is_complex:
137191
y_input = np.hstack([y_vec.real, y_vec.imag])
@@ -148,6 +202,7 @@ def residuals(p, x, y):
148202

149203
# Run Fit
150204
res = least_squares(residuals, initial_guess, args=(x_in, y_input),
205+
bounds=(lb, ub), method=method,
151206
loss=loss, f_scale=f_scale)
152207

153208
if not return_cov:
@@ -386,7 +441,13 @@ def do_fit(self, guesses=None, use_kmeans=False, n_clusters=10,
386441
"n_clusters": n_clusters if use_kmeans else None,
387442
"loss": loss,
388443
"f_scale": f_scale,
389-
"return_cov": return_cov
444+
"return_cov": return_cov,
445+
"lower_bounds": (self.lower_bounds.tolist()
446+
if isinstance(self.lower_bounds, np.ndarray)
447+
else self.lower_bounds),
448+
"upper_bounds": (self.upper_bounds.tolist()
449+
if isinstance(self.upper_bounds, np.ndarray)
450+
else self.upper_bounds),
390451
})
391452

392453
if guesses is None:

0 commit comments

Comments
 (0)