1+ import numpy as np
2+ import dask .array as da
3+ from scipy .optimize import least_squares
4+ from sklearn .cluster import KMeans
5+ import inspect
6+
class SidpyFitterRefactor:
    """
    A parallelized fitter for sidpy.Datasets that supports K-Means-based
    initial guesses for improved convergence on large datasets.

    Attributes
    ----------
    dataset : sidpy.Dataset
        The original sidpy dataset containing data and metadata.
    dask_data : dask.array.Array
        The underlying dask array used for parallel computation.
    model_func : callable
        The function to fit. Expected signature: f(x_axis, *params).
    guess_func : callable
        The function to generate initial guesses. Expected signature: f(x_axis, y_data).
    metadata : dict
        A dictionary containing fit parameters, model source code, and configuration.
    """

    def __init__(self, dataset, model_function, guess_function, ind_dims=(2,)):
        """
        Initializes the fitter and records reproducibility metadata.

        Parameters
        ----------
        dataset : sidpy.Dataset
            Dataset to be fitted.
        model_function : callable
            The model function to use for fitting. Signature: f(x_axis, *params).
        guess_function : callable
            The function to generate initial parameters for the model.
            Signature: f(x_axis, y_data).
        ind_dims : int or tuple of int, optional
            The indices of the dimensions to fit over. Default is (2,).
            # TODO: Change the default to be over the existing spectral dimension

        Raises
        ------
        TypeError
            If `dataset` is not a sidpy.Dataset.
        """
        # Imported lazily so the module can be imported without sidpy installed.
        import sidpy
        if not isinstance(dataset, sidpy.Dataset):
            raise TypeError("Dataset must be a sidpy.Dataset object.")

        self.dataset = dataset
        # sidpy.Dataset is dask-backed, so the dataset doubles as the compute
        # array here; it may be rechunked later in setup_calc().
        self.dask_data = dataset
        self.model_func = model_function
        self.guess_func = guess_function

        self.ndim = self.dataset.ndim
        # BUG FIX: the original `tuple(ind_dims) if isinstance(ind_dims, int)`
        # raised TypeError for a bare int (ints are not iterable). Wrap
        # scalars in a 1-tuple and normalize any other iterable to a tuple.
        self.ind_dims = (ind_dims,) if isinstance(ind_dims, int) else tuple(ind_dims)
        self.spat_dims = [d for d in range(self.ndim) if d not in self.ind_dims]

        # Standardize x_axis (coordinate values of the fit dimension(s)).
        self.x_axis = np.array([self.dataset._axes[d].values for d in self.ind_dims]).squeeze()

        self.is_complex = np.iscomplexobj(self.dataset)
        self.num_params = None  # Determined later by setup_calc()

        # --- Reproducibility Metadata ---
        self.metadata = {
            "fit_parameters": {
                "ind_dims": self.ind_dims,
                "is_complex": self.is_complex,
                "use_kmeans": False,  # Updated during do_fit
                "n_clusters": None    # Updated during do_fit
            },
            "source_code": {
                "model_function": self._get_source(model_function),
                "guess_function": self._get_source(guess_function)
            },
            "dataset_info": {
                "original_name": self.dataset.name,
                "original_shape": self.dataset.shape
            }
        }
78+
79+ def _get_source (self , func ):
80+ """Extracts source code from a function for metadata storage."""
81+ try :
82+ return inspect .getsource (func )
83+ except (TypeError , OSError ):
84+ return "Source code not available (function might be defined in a shell or compiled)."
85+
86+ def setup_calc (self , chunks = 'auto' ):
87+ """
88+ Prepares the calculation by rechunking and determining the parameter count.
89+
90+ Parameters
91+ ----------
92+ chunks : str or tuple, optional
93+ The chunk size for the dask array. Default is 'auto'.
94+ """
95+ if chunks :
96+ self .dask_data = self .dask_data .rechunk (chunks )
97+
98+ s_slice = [0 ] * self .ndim
99+ for d in self .ind_dims :
100+ s_slice [d ] = slice (None )
101+
102+ sample_y = np .array (self .dask_data [tuple (s_slice )]).ravel ()
103+ sample_guess = np .asarray (self .guess_func (self .x_axis , sample_y )).ravel ()
104+ self .num_params = len (sample_guess )
105+
106+ self .metadata ["num_params" ] = self .num_params
107+ print (f"Setup Complete. Params: { self .num_params } | Spatial Dims: { self .spat_dims } " )
108+
109+ def _prepare_block (self , block , ind_dims ):
110+ """Internal helper to reshape dask blocks into (Pixels, Spectrum)."""
111+ n_ind = len (ind_dims )
112+ dest = tuple (range (block .ndim - n_ind , block .ndim ))
113+ data = np .moveaxis (block , ind_dims , dest )
114+ spatial_shape = data .shape [:- n_ind ]
115+ flat_data = data .reshape (- 1 , np .prod (data .shape [- n_ind :]))
116+ return flat_data , spatial_shape
117+
118+ def _fit_logic (self , y_vec , x_in , initial_guess ):
119+ """
120+ Core optimization logic for a single pixel.
121+ """
122+ y_vec = np .squeeze (np .asarray (y_vec ))
123+ initial_guess = np .asarray (initial_guess ).ravel ()
124+
125+ if self .is_complex :
126+ y_input = np .hstack ([y_vec .real , y_vec .imag ])
127+ def residuals (p , x , y_s ):
128+ fit = np .squeeze (self .model_func (x , * p ))
129+ if fit .size != y_s .size :
130+ fit = np .hstack ([fit .real , fit .imag ])
131+ return y_s - fit
132+ res = least_squares (residuals , initial_guess , args = (x_in , y_input ))
133+ else :
134+ def residuals (p , x , y ):
135+ fit = np .ravel (self .model_func (x , * p ))
136+ return y - fit
137+ res = least_squares (residuals , initial_guess , args = (x_in , y_vec ))
138+ return res .x
139+
140+ def do_kmeans_guess (self , n_clusters = 10 ):
141+ """
142+ Performs K-Means clustering to find representative spectra for prior fitting.
143+
144+ Parameters
145+ ----------
146+ n_clusters : int, optional
147+ Number of clusters to use for K-Means. Default is 10.
148+
149+ Returns
150+ -------
151+ dask.array.Array
152+ A dask array containing the initial guesses for every pixel.
153+ """
154+ print (f"Starting K-Means Guess with { n_clusters } clusters..." )
155+
156+ # Cast to base Dask Array to bypass sidpy overrides
157+ pure_dask = da .Array (self .dataset .dask , self .dataset .name ,
158+ self .dataset .chunks , self .dataset .dtype )
159+
160+ n_spectral = np .prod ([self .dataset .shape [d ] for d in self .ind_dims ])
161+ total_pixels = self .dataset .size // n_spectral
162+
163+ data_move = da .moveaxis (pure_dask , self .ind_dims , - 1 )
164+ flat_data = data_move .reshape ((int (total_pixels ), int (n_spectral ))).compute ()
165+
166+ clustering_data = np .abs (flat_data ) if self .is_complex else flat_data
167+ denom = (clustering_data .max (axis = 1 , keepdims = True ) -
168+ clustering_data .min (axis = 1 , keepdims = True ) + 1e-12 )
169+ norm_data = (clustering_data - clustering_data .min (axis = 1 , keepdims = True )) / denom
170+
171+ km = KMeans (n_clusters = n_clusters , n_init = 10 , random_state = 42 )
172+ labels = km .fit_predict (norm_data )
173+
174+ print ("Fitting cluster means..." )
175+ priors_per_cluster = np .zeros ((n_clusters , self .num_params ))
176+ for i in range (n_clusters ):
177+ mask = (labels == i )
178+ if not np .any (mask ): continue
179+ mean_spec = flat_data [mask ].mean (axis = 0 )
180+ init_p = self .guess_func (self .x_axis , mean_spec )
181+ priors_per_cluster [i ] = self ._fit_logic (mean_spec , self .x_axis , init_p )
182+
183+ full_prior_flat = priors_per_cluster [labels ]
184+ spatial_shape = [self .dataset .shape [d ] for d in self .spat_dims ]
185+ full_prior_map = full_prior_flat .reshape (spatial_shape + [self .num_params ])
186+
187+ return da .from_array (full_prior_map , chunks = 'auto' )
188+
189+ def do_guess (self ):
190+ """Parallelized guess logic across all pixels."""
191+ def guess_worker (block , x_in , ind_dims , num_params ):
192+ block = np .asarray (block )
193+ flat_data , spat_shape = self ._prepare_block (block , ind_dims )
194+ out_flat = np .zeros ((flat_data .shape [0 ], num_params ))
195+ for i in range (flat_data .shape [0 ]):
196+ res = self .guess_func (x_in , flat_data [i ])
197+ out_flat [i ] = np .asarray (res ).ravel ()
198+ return out_flat .reshape (spat_shape + (num_params ,))
199+
200+ return self .dask_data .map_blocks (
201+ guess_worker , self .x_axis , self .ind_dims , self .num_params ,
202+ dtype = np .float32 , drop_axis = self .ind_dims , new_axis = [self .ndim ]
203+ )
204+
205+ def do_fit (self , guesses = None , use_kmeans = False , n_clusters = 10 ):
206+ """
207+ Executes the parallel fit across the dataset.
208+
209+ Parameters
210+ ----------
211+ guesses : dask.array.Array, optional
212+ Initial guesses. If None, generated automatically.
213+ use_kmeans : bool, optional
214+ Whether to use K-means priors. Default is False.
215+ n_clusters : int, optional
216+ Number of clusters if use_kmeans is True. Default is 10.
217+
218+ Returns
219+ -------
220+ dask.array.Array
221+ Dask array containing the optimized fit parameters.
222+ """
223+ # Update metadata for the current run
224+ self .metadata ["fit_parameters" ]["use_kmeans" ] = use_kmeans
225+ self .metadata ["fit_parameters" ]["n_clusters" ] = n_clusters if use_kmeans else None
226+
227+ if guesses is None :
228+ guesses = self .do_kmeans_guess (n_clusters ) if use_kmeans else self .do_guess ()
229+
230+ def fit_worker (data_block , guess_block , x_in , ind_dims , num_params ):
231+ data_block , guess_block = np .asarray (data_block ), np .asarray (guess_block )
232+ flat_data , spat_shape = self ._prepare_block (data_block , ind_dims )
233+ flat_guess = guess_block .reshape (- 1 , guess_block .shape [- 1 ])
234+
235+ out_flat = np .zeros ((flat_data .shape [0 ], num_params ))
236+ for i in range (flat_data .shape [0 ]):
237+ if flat_data [i ].size == 0 : continue
238+ out_flat [i ] = self ._fit_logic (flat_data [i ], x_in , flat_guess [i ])
239+ return out_flat .reshape (spat_shape + (num_params ,))
240+
241+ data_ind = tuple (range (self .ndim ))
242+ guess_ind = tuple (self .spat_dims + [self .ndim ])
243+
244+ return da .blockwise (
245+ fit_worker , guess_ind ,
246+ self .dask_data , data_ind ,
247+ guesses , guess_ind ,
248+ self .x_axis , None ,
249+ self .ind_dims , None ,
250+ self .num_params , None ,
251+ dtype = np .float32 , align_arrays = True , concatenate = True
252+ )