
Commit 05e9b2f

fitter refactored
1 parent a289a80

File tree

3 files changed (+635 -1 lines)


sidpy/proc/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -3,5 +3,6 @@
 """
 
 from . import fitter
+from . import fitter_refactor
 
-__all__ = ['fitter']
+__all__ = ['fitter', 'fitter_refactor']
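
With fitter_refactor exported from sidpy.proc, the original and refactored fitters can be imported side by side. A minimal sketch, assuming a sidpy installation that includes this commit:

from sidpy.proc import fitter, fitter_refactor  # both modules are now in __all__
Fitter = fitter_refactor.SidpyFitterRefactor    # class added by this commit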

sidpy/proc/fitter_refactor.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
import numpy as np
import dask.array as da
from scipy.optimize import least_squares
from sklearn.cluster import KMeans
import inspect


class SidpyFitterRefactor:
    """
    A parallelized fitter for sidpy.Datasets that supports K-Means-based
    initial guesses for improved convergence on large datasets.

    Attributes
    ----------
    dataset : sidpy.Dataset
        The original sidpy dataset containing data and metadata.
    dask_data : dask.array.Array
        The underlying dask array used for parallel computation.
    model_func : callable
        The function to fit. Expected signature: f(x_axis, *params).
    guess_func : callable
        The function to generate initial guesses. Expected signature: f(x_axis, y_data).
    metadata : dict
        A dictionary containing fit parameters, model source code, and configuration.
    """

    def __init__(self, dataset, model_function, guess_function, ind_dims=(2,)):
        """
        Initializes the SidpyFitterRefactor.

        Parameters
        ----------
        dataset : sidpy.Dataset
            Dataset to be fitted.
        model_function : callable
            The model function to use for fitting.
        guess_function : callable
            The function to generate initial parameters for the model.
        ind_dims : int or tuple of int, optional
            The indices of the dimensions to fit over. Default is (2,).
            # TODO: Change the default to be over the existing spectral dimension
        """
        import sidpy
        if not isinstance(dataset, sidpy.Dataset):
            raise TypeError("Dataset must be a sidpy.Dataset object.")

        self.dataset = dataset
        self.dask_data = dataset  # sidpy.Dataset subclasses dask.array.Array
        self.model_func = model_function
        self.guess_func = guess_function

        self.ndim = self.dataset.ndim
        # Normalize ind_dims to a tuple; wrapping a bare int avoids
        # the TypeError that tuple(int) would raise
        self.ind_dims = (ind_dims,) if isinstance(ind_dims, int) else tuple(ind_dims)
        self.spat_dims = [d for d in range(self.ndim) if d not in self.ind_dims]

        # Standardize x_axis (coordinate values)
        self.x_axis = np.array([self.dataset._axes[d].values for d in self.ind_dims]).squeeze()

        self.is_complex = np.iscomplexobj(self.dataset)
        self.num_params = None

        # --- Reproducibility Metadata ---
        self.metadata = {
            "fit_parameters": {
                "ind_dims": self.ind_dims,
                "is_complex": self.is_complex,
                "use_kmeans": False,  # Updated during do_fit
                "n_clusters": None    # Updated during do_fit
            },
            "source_code": {
                "model_function": self._get_source(model_function),
                "guess_function": self._get_source(guess_function)
            },
            "dataset_info": {
                "original_name": self.dataset.name,
                "original_shape": self.dataset.shape
            }
        }

    def _get_source(self, func):
        """Extracts source code from a function for metadata storage."""
        try:
            return inspect.getsource(func)
        except (TypeError, OSError):
            return "Source code not available (function might be defined in a shell or compiled)."

    def setup_calc(self, chunks='auto'):
        """
        Prepares the calculation by rechunking and determining the parameter count.

        Parameters
        ----------
        chunks : str or tuple, optional
            The chunk size for the dask array. Default is 'auto'.
        """
        if chunks:
            self.dask_data = self.dask_data.rechunk(chunks)

        # Slice out a single pixel while keeping the full spectral axes
        s_slice = [0] * self.ndim
        for d in self.ind_dims:
            s_slice[d] = slice(None)

        sample_y = np.array(self.dask_data[tuple(s_slice)]).ravel()
        sample_guess = np.asarray(self.guess_func(self.x_axis, sample_y)).ravel()
        self.num_params = len(sample_guess)

        self.metadata["num_params"] = self.num_params
        print(f"Setup Complete. Params: {self.num_params} | Spatial Dims: {self.spat_dims}")

    def _prepare_block(self, block, ind_dims):
        """Internal helper to reshape dask blocks into (Pixels, Spectrum)."""
        n_ind = len(ind_dims)
        dest = tuple(range(block.ndim - n_ind, block.ndim))
        data = np.moveaxis(block, ind_dims, dest)
        spatial_shape = data.shape[:-n_ind]
        flat_data = data.reshape(-1, np.prod(data.shape[-n_ind:]))
        return flat_data, spatial_shape

    def _fit_logic(self, y_vec, x_in, initial_guess):
        """
        Core optimization logic for a single pixel.
        """
        y_vec = np.squeeze(np.asarray(y_vec))
        initial_guess = np.asarray(initial_guess).ravel()

        if self.is_complex:
            # least_squares needs real residuals, so the real and imaginary
            # parts are stacked into one real-valued vector
            y_input = np.hstack([y_vec.real, y_vec.imag])

            def residuals(p, x, y_s):
                fit = np.squeeze(self.model_func(x, *p))
                if fit.size != y_s.size:
                    fit = np.hstack([fit.real, fit.imag])
                return y_s - fit

            res = least_squares(residuals, initial_guess, args=(x_in, y_input))
        else:
            def residuals(p, x, y):
                fit = np.ravel(self.model_func(x, *p))
                return y - fit

            res = least_squares(residuals, initial_guess, args=(x_in, y_vec))
        return res.x

    def do_kmeans_guess(self, n_clusters=10):
        """
        Performs K-Means clustering to find representative spectra for prior fitting.

        Parameters
        ----------
        n_clusters : int, optional
            Number of clusters to use for K-Means. Default is 10.

        Returns
        -------
        dask.array.Array
            A dask array containing the initial guesses for every pixel.
        """
        print(f"Starting K-Means Guess with {n_clusters} clusters...")

        # Cast to base Dask Array to bypass sidpy overrides
        pure_dask = da.Array(self.dataset.dask, self.dataset.name,
                             self.dataset.chunks, self.dataset.dtype)

        n_spectral = np.prod([self.dataset.shape[d] for d in self.ind_dims])
        total_pixels = self.dataset.size // n_spectral

        data_move = da.moveaxis(pure_dask, self.ind_dims, -1)
        flat_data = data_move.reshape((int(total_pixels), int(n_spectral))).compute()

        # Min-max normalize each spectrum so clustering is driven by line
        # shape rather than amplitude; 1e-12 guards against flat spectra
        clustering_data = np.abs(flat_data) if self.is_complex else flat_data
        denom = (clustering_data.max(axis=1, keepdims=True) -
                 clustering_data.min(axis=1, keepdims=True) + 1e-12)
        norm_data = (clustering_data - clustering_data.min(axis=1, keepdims=True)) / denom

        km = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
        labels = km.fit_predict(norm_data)

        print("Fitting cluster means...")
        priors_per_cluster = np.zeros((n_clusters, self.num_params))
        for i in range(n_clusters):
            mask = (labels == i)
            if not np.any(mask):
                continue
            mean_spec = flat_data[mask].mean(axis=0)
            init_p = self.guess_func(self.x_axis, mean_spec)
            priors_per_cluster[i] = self._fit_logic(mean_spec, self.x_axis, init_p)

        # Broadcast each cluster's fitted parameters to all of its member pixels
        full_prior_flat = priors_per_cluster[labels]
        spatial_shape = [self.dataset.shape[d] for d in self.spat_dims]
        full_prior_map = full_prior_flat.reshape(spatial_shape + [self.num_params])

        return da.from_array(full_prior_map, chunks='auto')

    def do_guess(self):
        """Parallelized guess logic across all pixels."""
        def guess_worker(block, x_in, ind_dims, num_params):
            block = np.asarray(block)
            flat_data, spat_shape = self._prepare_block(block, ind_dims)
            out_flat = np.zeros((flat_data.shape[0], num_params))
            for i in range(flat_data.shape[0]):
                res = self.guess_func(x_in, flat_data[i])
                out_flat[i] = np.asarray(res).ravel()
            return out_flat.reshape(spat_shape + (num_params,))

        return self.dask_data.map_blocks(
            guess_worker, self.x_axis, self.ind_dims, self.num_params,
            dtype=np.float32, drop_axis=self.ind_dims, new_axis=[self.ndim]
        )

    def do_fit(self, guesses=None, use_kmeans=False, n_clusters=10):
        """
        Executes the parallel fit across the dataset.

        Parameters
        ----------
        guesses : dask.array.Array, optional
            Initial guesses. If None, generated automatically.
        use_kmeans : bool, optional
            Whether to use K-Means priors. Default is False.
        n_clusters : int, optional
            Number of clusters if use_kmeans is True. Default is 10.

        Returns
        -------
        dask.array.Array
            Dask array containing the optimized fit parameters.
        """
        # Update metadata for the current run
        self.metadata["fit_parameters"]["use_kmeans"] = use_kmeans
        self.metadata["fit_parameters"]["n_clusters"] = n_clusters if use_kmeans else None

        if guesses is None:
            guesses = self.do_kmeans_guess(n_clusters) if use_kmeans else self.do_guess()

        def fit_worker(data_block, guess_block, x_in, ind_dims, num_params):
            data_block, guess_block = np.asarray(data_block), np.asarray(guess_block)
            flat_data, spat_shape = self._prepare_block(data_block, ind_dims)
            flat_guess = guess_block.reshape(-1, guess_block.shape[-1])

            out_flat = np.zeros((flat_data.shape[0], num_params))
            for i in range(flat_data.shape[0]):
                if flat_data[i].size == 0:
                    continue
                out_flat[i] = self._fit_logic(flat_data[i], x_in, flat_guess[i])
            return out_flat.reshape(spat_shape + (num_params,))

        data_ind = tuple(range(self.ndim))
        guess_ind = tuple(self.spat_dims + [self.ndim])

        return da.blockwise(
            fit_worker, guess_ind,
            self.dask_data, data_ind,
            guesses, guess_ind,
            self.x_axis, None,
            self.ind_dims, None,
            self.num_params, None,
            dtype=np.float32, align_arrays=True, concatenate=True
        )
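
For orientation, a minimal usage sketch follows. The Gaussian model, the guess heuristic, and the synthetic dataset are illustrative assumptions; only the SidpyFitterRefactor API itself comes from this commit.

import numpy as np
import sidpy
from sidpy.proc.fitter_refactor import SidpyFitterRefactor

# Hypothetical model matching the documented signature f(x_axis, *params)
def gaussian(x, amp, center, width):
    return amp * np.exp(-((x - center) ** 2) / (2 * width ** 2))

# Hypothetical guess function matching the documented signature f(x_axis, y_data)
def gaussian_guess(x, y):
    amp = y.max() - y.min()
    center = x[np.argmax(y)]
    width = (x.max() - x.min()) / 10  # crude starting width
    return np.array([amp, center, width])

# Illustrative 10x10 map of 128-point spectra; dimension 2 is spectral
dataset = sidpy.Dataset.from_array(np.random.rand(10, 10, 128))

fitter = SidpyFitterRefactor(dataset, gaussian, gaussian_guess, ind_dims=(2,))
fitter.setup_calc(chunks='auto')                       # rechunks and counts parameters
params = fitter.do_fit(use_kmeans=True, n_clusters=5)  # lazy dask array of parameters
result = params.compute()                              # shape (10, 10, 3)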
