Skip to content

Commit 4be7234

Browse files
authored
Merge pull request #94 from ziatdinovmax/master
Add basic spectral analyzer
2 parents b510e83 + b3c050a commit 4be7234

File tree

4 files changed

+319
-326
lines changed

4 files changed

+319
-326
lines changed

atomai/stat/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .multivar import (imlocal, calculate_transition_matrix,
22
sum_transitions, update_classes)
33
from .fft_nmf import SlidingFFTNMF
4+
from .unmixer import SpectralUnmixer
45

56
__all__ = ['imlocal', 'calculate_transition_matrix', 'sum_transitions',
6-
'update_classes', 'SlidingFFTNMF']
7+
'update_classes', 'SlidingFFTNMF', 'SpectralUnmixer']

atomai/stat/fft_nmf.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,6 @@ def __init__(self, window_size_x=None, window_size_y=None,
3333
self._user_window_step_x = window_step_x
3434
self._user_window_step_y = window_step_y
3535

36-
# These will be set in _calculate_window_params or use defaults
37-
self.window_size_x = window_size_x or 64 # Default fallback
38-
self.window_size_y = window_size_y or 64
39-
self.window_step_x = window_step_x or 16
40-
self.window_step_y = window_step_y or 16
41-
4236
self.interpol_factor = interpolation_factor
4337
self.zoom_factor = zoom_factor
4438
self.hamming_filter = hamming_filter

atomai/stat/unmixer.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import warnings
2+
from sklearn.decomposition import NMF, PCA, FastICA
3+
from sklearn.mixture import GaussianMixture
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
7+
8+
class SpectralUnmixer:
    """
    Applies various decomposition algorithms to hyperspectral data for
    spectral unmixing and component analysis.

    Supported methods: 'nmf', 'pca', 'ica', 'gmm'.
    """
    def __init__(self, method: str = 'nmf', n_components: int = 4, normalize: bool = False, **kwargs):
        """
        Initializes the unmixer.

        Args:
            method (str): The decomposition method to use.
                Options: 'nmf', 'pca', 'ica', 'gmm'.
            n_components (int): The number of components to find.
            normalize (bool): If True, each spectrum is L1-normalized (sums to 1)
                before decomposition. This is highly recommended for NMF.
            **kwargs: Additional keyword arguments to pass to the
                underlying sklearn model (e.g., max_iter, pca_dims).

        Raises:
            ValueError: If `method` is not one of the supported options.
        """
        self.method = method
        self.n_components = n_components
        self.normalize = normalize
        self.kwargs = kwargs

        # 'pca_dims' configures the PCA pre-processing step of the GMM
        # workflow (consumed in fit()); it is NOT an sklearn estimator
        # parameter, so forwarding it to a model constructor would raise
        # TypeError. Strip it from the kwargs passed to sklearn.
        model_kwargs = {k: v for k, v in kwargs.items() if k != 'pca_dims'}

        if self.method == 'nmf':
            self.model = NMF(n_components=n_components, **model_kwargs)
        elif self.method == 'pca':
            self.model = PCA(n_components=n_components, **model_kwargs)
        elif self.method == 'ica':
            self.model = FastICA(n_components=n_components, whiten='unit-variance',
                                 max_iter=self.kwargs.get("max_iter", 200))
        elif self.method == 'gmm':
            self.model = GaussianMixture(n_components=n_components, **model_kwargs)
        else:
            raise ValueError("Method not recognized. Choose from 'nmf', 'pca', 'ica', 'gmm'.")
        self.components_ = None       # (n_components, e) component spectra, set by fit()
        self.abundance_maps_ = None   # (h, w, n_components) abundances, set by fit()
        self.image_shape_ = None      # (h, w) of the last fitted cube

    def fit(self, hspy_data: np.ndarray):
        """
        Fits the selected model to a hyperspectral data cube.

        Args:
            hspy_data (np.ndarray): 3D cube of shape (h, w, e), where e is
                the spectral (energy) dimension.

        Returns:
            Tuple of (components_, abundance_maps_): component spectra of
            shape (n_components, e) and abundance maps of shape
            (h, w, n_components).

        Raises:
            ValueError: If the input is not 3D, or if 'pca_dims' is neither
                an int nor a float in (0, 1) for the GMM workflow.
        """
        if hspy_data.ndim != 3:
            raise ValueError("Input data must be a 3D hyperspectral cube (h, w, e).")

        self.image_shape_ = hspy_data.shape[:2]
        h, w, e = hspy_data.shape
        spectra_matrix = hspy_data.reshape((h * w, e))

        # Data to be passed to the fitting algorithm
        spectra_to_fit = spectra_matrix.copy()

        # Optional per-spectrum L1 normalization
        if self.normalize:
            print("Normalizing each spectrum to sum to 1 (L1 norm)...")
            # Store norms for later rescaling of abundances
            l1_norms = np.sum(spectra_matrix, axis=1, keepdims=True)
            # Avoid division by zero for empty spectra (e.g., from outside scan region)
            l1_norms[l1_norms == 0] = 1
            spectra_to_fit = spectra_matrix / l1_norms

        print(f"Fitting data with {self.method.upper()}...")

        # NMF non-negativity check
        if self.method == 'nmf':
            min_val = np.min(spectra_to_fit)
            if min_val < 0:
                warnings.warn(f"NMF requires non-negative data. Shifting data by {-min_val:.2f}.")
                spectra_to_fit = spectra_to_fit - min_val

        # GMM's PCA+GMM robust workflow
        if self.method == 'gmm':
            # PCA dimension selection: int = fixed number of components,
            # float in (0, 1) = explained-variance threshold.
            pca_param = self.kwargs.get('pca_dims', 0.99)  # Default to 99% variance

            print("Applying PCA for dimensionality reduction before GMM...")
            # First, fit PCA on all components to check variance
            pca_full = PCA()
            pca_full.fit(spectra_to_fit)

            if isinstance(pca_param, int):
                n_components_pca = pca_param
                print(f"Using a fixed number of {n_components_pca} principal components.")
            elif isinstance(pca_param, float) and 0 < pca_param < 1:
                cumulative_variance = np.cumsum(pca_full.explained_variance_ratio_)
                # Find the number of components needed to reach the threshold
                n_components_pca = np.searchsorted(cumulative_variance, pca_param) + 1
                explained_var_actual = cumulative_variance[n_components_pca - 1]
                print(
                    f"Found {n_components_pca} components that explain {explained_var_actual:.1%} "
                    f"of the variance (threshold was {pca_param:.1%})."
                )
            else:
                raise ValueError("'pca_dims' must be an int or a float between 0 and 1.")

            # Perform the final PCA transformation with the optimal number of components
            pca_final = PCA(n_components=n_components_pca)
            projected_data = pca_final.fit_transform(spectra_to_fit)

            # Fit GMM on the low-dimensional data; abundances are the
            # per-cluster posterior probabilities.
            self.model.fit(projected_data)
            labels = self.model.predict(projected_data)
            abundances_unscaled = self.model.predict_proba(projected_data)
            # Component spectra = mean raw spectrum of each GMM cluster
            self.components_ = np.array([
                spectra_matrix[labels == i].mean(axis=0)
                for i in range(self.n_components)
            ])
        else:  # For NMF, PCA, ICA
            abundances_unscaled = self.model.fit_transform(spectra_to_fit)
            self.components_ = self.model.components_

        # Rescale abundances if data was normalized
        if self.normalize:
            abundances = abundances_unscaled * l1_norms
        else:
            abundances = abundances_unscaled

        # Reshape abundance maps back to image dimensions
        self.abundance_maps_ = abundances.reshape((h, w, self.n_components))

        print("Fit complete.")
        return self.components_, self.abundance_maps_

    def plot_results(self, **kwargs):
        """
        Plots each component spectrum (top row) above its abundance map
        (bottom row).

        Args:
            **kwargs: 'cmap' (str, default 'seismic') for the abundance maps
                and 'figsize' (tuple) for the matplotlib figure.
        """
        if self.components_ is None:
            print("You must run .fit() first.")
            return

        cmap = kwargs.get("cmap", 'seismic')

        n_cols = self.n_components
        # squeeze=False keeps axes as a 2D array even when n_components == 1,
        # so the axes[row, col] indexing below cannot raise IndexError.
        fig, axes = plt.subplots(2, n_cols, figsize=kwargs.get("figsize", (n_cols * 3.5, 6)),
                                 squeeze=False)

        for i in range(self.n_components):
            # Plot component spectrum
            axes[0, i].plot(self.components_[i, :])
            axes[0, i].set_title(f'{self.method.upper()} Component {i+1}')
            axes[0, i].set_xlabel('Energy Bin')
            if i == 0:
                axes[0, i].set_ylabel('Intensity')

            # Plot abundance map
            ax_map = axes[1, i]
            im = ax_map.imshow(self.abundance_maps_[..., i], cmap=cmap)
            ax_map.set_title(f'Abundance Map {i+1}')
            ax_map.axis('off')
            fig.colorbar(im, ax=axes[1, i], fraction=0.046, pad=0.04)

        plt.tight_layout()
        plt.show()

0 commit comments

Comments
 (0)