1+ import warnings
2+ from sklearn .decomposition import NMF , PCA , FastICA
3+ from sklearn .mixture import GaussianMixture
4+ import numpy as np
5+ import matplotlib .pyplot as plt
6+
7+
class SpectralUnmixer:
    """
    Applies various decomposition algorithms to hyperspectral data for
    spectral unmixing and component analysis.

    Supported methods: 'nmf', 'pca', 'ica', 'gmm'.

    For 'gmm' the spectra are first projected with PCA and the mixture model
    is fitted in that reduced space; the size of the projection is controlled
    by the ``pca_dims`` keyword argument.

    Attributes (all None until ``fit`` has run):
        components_: One spectrum per component, shape (n_components, e).
        abundance_maps_: Per-pixel component weights, shape
            (h, w, n_components).
        image_shape_: The (h, w) shape of the fitted cube.
    """

    _SUPPORTED_METHODS = ('nmf', 'pca', 'ica', 'gmm')
    # Options consumed by this class itself; they are never forwarded to sklearn.
    _WRAPPER_ONLY_KWARGS = ('pca_dims',)
    # Default GMM pre-projection: keep enough components for 99% of the variance.
    _DEFAULT_PCA_DIMS = 0.99

    def __init__(self, method: str = 'nmf', n_components: int = 4, normalize: bool = False, **kwargs):
        """
        Initializes the unmixer.

        Args:
            method (str): The decomposition method to use.
                Options: 'nmf', 'pca', 'ica', 'gmm'.
            n_components (int): The number of components to find.
            normalize (bool): If True, each spectrum is L1-normalized (sums to 1)
                before decomposition. This is highly recommended for NMF.
            **kwargs: Additional keyword arguments to pass to the underlying
                sklearn model (e.g., max_iter). The special key ``pca_dims``
                is consumed by this class instead: for 'gmm' it selects the
                PCA pre-projection, either a positive int (number of
                components) or a float in (0, 1) (explained-variance
                threshold, default 0.99). Other methods ignore it.

        Raises:
            ValueError: If ``method`` is not supported, or if ``method`` is
                'gmm' and ``pca_dims`` is invalid.
        """
        # Validate before touching any model class so that bad input fails fast.
        if method not in self._SUPPORTED_METHODS:
            raise ValueError("Method not recognized. Choose from 'nmf', 'pca', 'ica', 'gmm'.")
        if method == 'gmm':
            self._check_pca_dims(kwargs.get('pca_dims', self._DEFAULT_PCA_DIMS))

        self.method = method
        self.n_components = n_components
        self.normalize = normalize
        # Kept unfiltered: fit() reads 'pca_dims' from here.
        self.kwargs = kwargs

        # sklearn estimators reject unknown keyword arguments, so strip the
        # options that only this wrapper understands.
        model_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in self._WRAPPER_ONLY_KWARGS
        }

        if method == 'nmf':
            self.model = NMF(n_components=n_components, **model_kwargs)
        elif method == 'pca':
            self.model = PCA(n_components=n_components, **model_kwargs)
        elif method == 'ica':
            # Wrapper defaults first, so caller-supplied options
            # (random_state, tol, max_iter, ...) take precedence.
            ica_kwargs = {'whiten': 'unit-variance', 'max_iter': 200, **model_kwargs}
            self.model = FastICA(n_components=n_components, **ica_kwargs)
        else:  # 'gmm'
            self.model = GaussianMixture(n_components=n_components, **model_kwargs)

        self.components_ = None
        self.abundance_maps_ = None
        self.image_shape_ = None

    @staticmethod
    def _check_pca_dims(pca_dims):
        """Return ``pca_dims`` as a plain int or float, or raise ValueError."""
        # bool is a subclass of int; True/False is never a meaningful count.
        if not isinstance(pca_dims, bool):
            if isinstance(pca_dims, (int, np.integer)) and pca_dims >= 1:
                return int(pca_dims)
            if isinstance(pca_dims, (float, np.floating)) and 0 < pca_dims < 1:
                return float(pca_dims)
        raise ValueError("'pca_dims' must be a positive int or a float between 0 and 1.")

    def _reduce_with_pca(self, spectra, pca_dims):
        """Project ``spectra`` onto its leading principal components for GMM."""
        print("Applying PCA for dimensionality reduction before GMM...")

        if isinstance(pca_dims, int):
            # A fixed size needs no variance analysis, hence no full PCA.
            print(f"Using a fixed number of {pca_dims} principal components.")
            return PCA(n_components=pca_dims).fit_transform(spectra)

        # Variance threshold: fit every component once, then keep the leading
        # ones. The scores are reused rather than refitting a second PCA.
        pca_full = PCA()
        scores = pca_full.fit_transform(spectra)
        cumulative_variance = np.cumsum(pca_full.explained_variance_ratio_)
        n_keep = int(np.searchsorted(cumulative_variance, pca_dims)) + 1
        # Rounding can leave the cumulative sum just below the threshold, which
        # would push the index one past the end; clamp to what is available.
        n_keep = min(n_keep, cumulative_variance.size)
        explained_var_actual = cumulative_variance[n_keep - 1]
        print(
            f"Found {n_keep} components that explain {explained_var_actual:.1%} "
            f"of the variance (threshold was {pca_dims:.1%})."
        )
        return np.ascontiguousarray(scores[:, :n_keep])

    def _mean_spectra_by_label(self, spectra_matrix, labels):
        """Average the original spectra of each hard cluster assignment."""
        mean_spectra = []
        empty_clusters = []
        for index in range(self.n_components):
            members = spectra_matrix[labels == index]
            if members.shape[0]:
                mean_spectra.append(members.mean(axis=0))
            else:
                # No pixel was assigned here: keep an explicit NaN row rather
                # than triggering NumPy's "Mean of empty slice" warning.
                mean_spectra.append(np.full(spectra_matrix.shape[1], np.nan))
                empty_clusters.append(index)
        if empty_clusters:
            warnings.warn(
                f"GMM components {empty_clusters} have no assigned pixels; "
                "their spectra are set to NaN.",
                stacklevel=3,
            )
        return np.array(mean_spectra)

    def fit(self, hspy_data: np.ndarray):
        """
        Fits the selected model to a hyperspectral data cube.

        Args:
            hspy_data: Array-like cube of shape (h, w, e): two image axes
                followed by the spectral axis.

        Returns:
            tuple: ``(components_, abundance_maps_)`` with shapes
            (n_components, e) and (h, w, n_components).

        Raises:
            ValueError: If the input is not three-dimensional, or if the
                method is 'gmm' and ``pca_dims`` is invalid.
        """
        hspy_data = np.asarray(hspy_data)
        if hspy_data.ndim != 3:
            raise ValueError("Input data must be a 3D hyperspectral cube (h, w, e).")

        pca_dims = None
        if self.method == 'gmm':
            # Checked up front so a bad setting fails before any expensive fit.
            pca_dims = self._check_pca_dims(
                self.kwargs.get('pca_dims', self._DEFAULT_PCA_DIMS)
            )

        h, w, e = hspy_data.shape
        self.image_shape_ = (h, w)
        spectra_matrix = hspy_data.reshape((h * w, e))

        # Optional per-spectrum L1 normalization
        l1_norms = None
        if self.normalize:
            print("Normalizing each spectrum to sum to 1 (L1 norm)...")
            # Store norms for later rescaling of abundances
            l1_norms = np.sum(spectra_matrix, axis=1, keepdims=True)
            # Avoid division by zero for empty spectra (e.g., from outside scan region)
            l1_norms[l1_norms == 0] = 1
            spectra_to_fit = spectra_matrix / l1_norms
        else:
            # Work on a copy so the caller's array is never modified.
            spectra_to_fit = spectra_matrix.copy()

        print(f"Fitting data with {self.method.upper()}...")

        # NMF cannot handle negative values; shift so the minimum becomes zero.
        if self.method == 'nmf':
            min_val = np.min(spectra_to_fit)
            if min_val < 0:
                warnings.warn(
                    f"NMF requires non-negative data. Shifting data by {-min_val:.2f}.",
                    stacklevel=2,
                )
                spectra_to_fit = spectra_to_fit - min_val

        if self.method == 'gmm':
            # Cluster in the PCA-reduced space, then report each cluster as the
            # mean of the original (un-normalized) spectra assigned to it.
            projected_data = self._reduce_with_pca(spectra_to_fit, pca_dims)
            self.model.fit(projected_data)
            labels = self.model.predict(projected_data)
            abundances_unscaled = self.model.predict_proba(projected_data)
            self.components_ = self._mean_spectra_by_label(spectra_matrix, labels)
        else:  # For NMF, PCA, ICA
            abundances_unscaled = self.model.fit_transform(spectra_to_fit)
            self.components_ = self.model.components_

        # Undo the per-spectrum normalization so maps reflect original intensity.
        if l1_norms is not None:
            abundances = abundances_unscaled * l1_norms
        else:
            abundances = abundances_unscaled

        # Reshape abundance maps back to image dimensions
        self.abundance_maps_ = abundances.reshape((h, w, self.n_components))

        print("Fit complete.")
        return self.components_, self.abundance_maps_

    def plot_results(self, **kwargs):
        """
        Plots each component spectrum above its abundance map.

        Prints a hint and returns without plotting if ``fit`` has not run.

        Args:
            **kwargs: Optional ``cmap`` (colormap for the maps, default
                'seismic') and ``figsize`` (default scales with the number
                of components).
        """
        if self.components_ is None:
            print("You must run .fit() first.")
            return

        cmap = kwargs.get("cmap", 'seismic')
        n_cols = self.n_components
        figsize = kwargs.get("figsize", (n_cols * 3.5, 6))

        # squeeze=False keeps `axes` two-dimensional even for one component;
        # otherwise the [row, column] indexing below fails.
        fig, axes = plt.subplots(2, n_cols, figsize=figsize, squeeze=False)

        for i in range(n_cols):
            # Top row: component spectrum
            ax_spectrum = axes[0, i]
            ax_spectrum.plot(self.components_[i, :])
            ax_spectrum.set_title(f'{self.method.upper()} Component {i + 1}')
            ax_spectrum.set_xlabel('Energy Bin')
            if i == 0:
                ax_spectrum.set_ylabel('Intensity')

            # Bottom row: abundance map
            ax_map = axes[1, i]
            im = ax_map.imshow(self.abundance_maps_[..., i], cmap=cmap)
            ax_map.set_title(f'Abundance Map {i + 1}')
            ax_map.axis('off')
            fig.colorbar(im, ax=ax_map, fraction=0.046, pad=0.04)

        plt.tight_layout()
        plt.show()
0 commit comments