Skip to content

Commit 91dfdd7

Browse files
authored
Merge pull request #34 from CosmoStat/gmca
Test added in mr_gmca, running well, comparing GMCA when using wavelet or curvelet
2 parents 48824bc + c86baa1 commit 91dfdd7

File tree

2 files changed

+412
-1
lines changed

2 files changed

+412
-1
lines changed

pycs/sparsity/sparse2d/bss_eval.py

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
"""
2+
Created on Mar 30, 2015
3+
4+
@author: Ming Jiang and Jean-Luc Starck
5+
6+
Routines for GMCA evaluation
7+
"""
8+
9+
import numpy as np
10+
from os import remove
11+
from subprocess import check_call
12+
from subprocess import call
13+
from datetime import datetime
14+
from astropy.io import fits
15+
import shlex
16+
from pycs.misc.cosmostat_init import *
17+
from pycs.misc.cosmostat_init import writefits
18+
from skimage import data, color
19+
from skimage.transform import resize
20+
21+
def load_source_images(shape=(128, 128)):
22+
# Load grayscale source images and resize
23+
img1 = resize(data.camera(), shape, anti_aliasing=True)
24+
img2 = color.rgb2gray(resize(data.astronaut(), shape, anti_aliasing=True))
25+
26+
# Normalize images
27+
img1 = (img1 - np.mean(img1)) / np.std(img1)
28+
img2 = (img2 - np.mean(img2)) / np.std(img2)
29+
30+
# Stack into a single array: (2, H, W)
31+
sources = np.stack([img1, img2], axis=0)
32+
return sources
33+
34+
def mix_sources_images(sources):
35+
n_sources, H, W = sources.shape
36+
n_pixels = H * W
37+
38+
# Flatten sources: (2, H*W)
39+
S = sources.reshape(n_sources, n_pixels)
40+
41+
# Create random mixing matrix: (3, 2)
42+
A = np.random.randn(3, 2)
43+
44+
# Mix: (3, H*W)
45+
mixed = A @ S
46+
47+
# Reshape to image form: (3, H, W)
48+
mixed_images = mixed.reshape(3, H, W)
49+
return mixed_images, A
50+
51+
52+
def mix_sources_images_noise(sources, noise_level=0.05):
53+
n_sources, H, W = sources.shape
54+
n_pixels = H * W
55+
56+
# Flatten source images
57+
S = sources.reshape(n_sources, n_pixels)
58+
59+
# Generate random mixing matrix
60+
A = np.random.randn(3, 2)
61+
62+
# Mix sources
63+
mixed = A @ S # Shape: (3, n_pixels)
64+
65+
# Add Gaussian noise
66+
noise = np.random.normal(scale=noise_level, size=mixed.shape)
67+
mixed += noise
68+
69+
# Reshape back to images
70+
mixed_images = mixed.reshape(3, H, W)
71+
return mixed_images, A
72+
73+
74+
def reorder_and_fix_sign(true_sources, estimated_sources):
75+
"""
76+
Reorder and apply sign correction to estimated_sources so they match true_sources.
77+
78+
Parameters:
79+
true_sources: np.ndarray of shape (n_sources, H, W)
80+
estimated_sources: np.ndarray of shape (n_sources, H, W)
81+
82+
Returns:
83+
corrected_sources: np.ndarray of shape (n_sources, H, W)
84+
"""
85+
n_sources, H, W = true_sources.shape
86+
S_true = true_sources.reshape(n_sources, -1)
87+
S_est = estimated_sources.reshape(n_sources, -1)
88+
89+
# Normalize
90+
S_true = (S_true - S_true.mean(axis=1, keepdims=True))
91+
S_true /= np.linalg.norm(S_true, axis=1, keepdims=True)
92+
S_est = (S_est - S_est.mean(axis=1, keepdims=True))
93+
S_est /= np.linalg.norm(S_est, axis=1, keepdims=True)
94+
95+
# Correlation matrix
96+
corr = S_true @ S_est.T # (n_true, n_est)
97+
98+
# Reorder and sign-correct
99+
used = set()
100+
corrected_sources = np.zeros_like(true_sources)
101+
102+
for i in range(n_sources):
103+
idx = np.argmax(np.abs(corr[i]))
104+
while idx in used:
105+
corr[i, idx] = 0
106+
idx = np.argmax(np.abs(corr[i]))
107+
used.add(idx)
108+
sign = np.sign(corr[i, idx])
109+
corrected_sources[i] = sign * estimated_sources[idx]
110+
111+
return corrected_sources
112+
113+
def compute_sdr(true_sources, estimated_sources):
114+
"""
115+
Compute SDR for each pair of true and estimated sources.
116+
117+
Parameters:
118+
true_sources: np.ndarray of shape (n_sources, H, W)
119+
estimated_sources: np.ndarray of shape (n_sources, H, W)
120+
121+
Returns:
122+
sdr_values: list of SDR values for each source
123+
"""
124+
sdr_values = []
125+
for i in range(true_sources.shape[0]):
126+
s_true = true_sources[i].flatten()
127+
s_est = estimated_sources[i].flatten()
128+
noise = s_true - s_est
129+
sdr = 10 * np.log10(np.sum(s_true ** 2) / np.sum(noise ** 2))
130+
sdr_values.append(sdr)
131+
return sdr_values
132+
133+
134+
def amari_error(A_true, A_est):
135+
"""
136+
Compute the Amari error between the true and estimated mixing matrices.
137+
138+
Parameters:
139+
A_true: np.ndarray (n_obs, n_sources) — ground truth mixing matrix
140+
A_est: np.ndarray (n_obs, n_sources) — estimated mixing matrix
141+
142+
Returns:
143+
Amari error (float)
144+
"""
145+
try:
146+
# Estimate the unmixing matrix
147+
W_est = np.linalg.pinv(A_est) # shape: (n_sources, n_obs)
148+
G = W_est @ A_true # shape: (n_sources, n_sources)
149+
except np.linalg.LinAlgError:
150+
return np.inf
151+
152+
# DEBUG
153+
if G.shape[0] != G.shape[1]:
154+
raise ValueError(f"G should be square, but got shape {G.shape}. Check input shapes.")
155+
156+
G = np.abs(G)
157+
row_sums = np.sum(G, axis=1, keepdims=True)
158+
col_sums = np.sum(G, axis=0, keepdims=True)
159+
160+
row_error = np.sum(np.sum(G / row_sums, axis=1) - 1)
161+
col_error = np.sum(np.sum(G / col_sums, axis=0) - 1)
162+
163+
return (row_error + col_error) / (2 * G.shape[0])
164+
165+
# Metrics
166+
167+
def evaluate(A0, S0, A, S, corrPerm=False):
168+
"""Computes the NMSE and the CA.
169+
170+
Parameters
171+
----------
172+
A0: np.ndarray
173+
(m,n) float array, ground truth mixing matrix
174+
S0: np.ndarray
175+
(n,p) float array, ground truth sources
176+
A: np.ndarray
177+
(m,n) float array, estimated mixing matrix
178+
S: np.ndarray
179+
(n,p) float array, estimated sources
180+
corrPerm: bool
181+
correct permutation of A and S (in-place updates)
182+
perScale: bool
183+
calculate NMSE per wavelet scale
184+
nscales: int
185+
number of wavelet detail scales
186+
S0wt: np.ndarray
187+
(m,n,nscales+1) float array, wavelet transform of S0, optional (to accelerate)
188+
189+
Returns
190+
-------
191+
(float,float) or (float,float,np.ndarray)
192+
CA,
193+
NMSE,
194+
NMSE per scale if perScale ((nscales+1,) float array)
195+
"""
196+
197+
if not corrPerm:
198+
A = A.copy()
199+
S = S.copy()
200+
201+
n = np.shape(A0)[1]
202+
203+
corr_perm(A0, S0, A, S, inplace=True)
204+
205+
# CA = -10 * np.log10(np.mean(np.abs(np.dot(np.linalg.pinv(A), A0) - np.eye(n))))
206+
CA = (np.mean(np.abs(np.dot(np.linalg.pinv(A), A0) - np.eye(n))))
207+
# NMSE = -10 * np.log10(np.sum((S0-S)**2)/np.sum(S0**2))
208+
NMSE = (np.sum((S0-S)**2)/np.sum(S0**2))
209+
210+
return CA, NMSE
211+
212+
213+
214+
def corr_perm(A0, S0, A, S, inplace=False, optInd=False):
215+
"""Correct the permutation of the solution.
216+
217+
Parameters
218+
----------
219+
A0: np.ndarray
220+
(m,n) float array, ground truth mixing matrix
221+
S0: np.ndarray
222+
(n,p) float array, ground truth sources
223+
A: np.ndarray
224+
(m,n) float array, estimated mixing matrix
225+
S: np.ndarray
226+
(n,p) float array, estimated sources
227+
inplace: bool
228+
in-place update of A and S
229+
optInd: bool
230+
return permutation
231+
232+
Returns
233+
-------
234+
None or np.ndarray or (np.ndarray,np.ndarray) or (np.ndarray,np.ndarray,np.ndarray)
235+
A (if not inplace),
236+
S (if not inplace),
237+
ind (if optInd)
238+
"""
239+
240+
A0 = A0.copy()
241+
S0 = S0.copy()
242+
if not inplace:
243+
A = A.copy()
244+
S = S.copy()
245+
246+
n = np.shape(A0)[1]
247+
248+
for i in range(0, n):
249+
S[i, :] *= (1e-24 + np.linalg.norm(A[:, i]))
250+
A[:, i] /= (1e-24 + np.linalg.norm(A[:, i]))
251+
S0[i, :] *= (1e-24 + np.linalg.norm(A0[:, i]))
252+
A0[:, i] /= (1e-24 + np.linalg.norm(A0[:, i]))
253+
254+
try:
255+
diff = abs(np.dot(np.linalg.inv(np.dot(A0.T, A0)), np.dot(A0.T, A)))
256+
except np.linalg.LinAlgError:
257+
diff = abs(np.dot(np.linalg.pinv(A0), A))
258+
print('Warning! Pseudo-inverse used.')
259+
260+
ind = np.arange(0, n)
261+
262+
for i in range(0, n):
263+
ind[i] = np.where(diff[i, :] == max(diff[i, :]))[0][0]
264+
265+
A[:] = A[:, ind.astype(int)]
266+
S[:] = S[ind.astype(int), :]
267+
268+
for i in range(0, n):
269+
p = np.sum(S[i, :] * S0[i, :])
270+
if p < 0:
271+
S[i, :] = -S[i, :]
272+
A[:, i] = -A[:, i]
273+
274+
if inplace and not optInd:
275+
return None
276+
elif inplace and optInd:
277+
return ind
278+
elif not optInd:
279+
return A, S
280+
else:
281+
return A, S, ind
282+
283+
284+
def nmse(S0, S):
285+
"""Compute the normalized mean square error (NMSE) in dB.
286+
287+
Parameters
288+
----------
289+
S0: np.ndarray
290+
(n,p) float array, ground truth sources
291+
S: np.ndarray
292+
(n,p) float array, estimated sources
293+
294+
Returns
295+
-------
296+
float
297+
NMSE (dB)
298+
"""
299+
return -10 * np.log10(np.sum((S0-S)**2)/np.sum(S0**2))
300+
301+
302+
def ca(A0, A):
303+
"""Compute the criterion on A (CA) in dB.
304+
305+
Parameters
306+
----------
307+
A0: np.ndarray
308+
(m,n) float array, ground truth mixing matrix
309+
A: np.ndarray
310+
(m,n) float array, estimated mixing matrix
311+
312+
Returns
313+
-------
314+
float
315+
CA (dB)
316+
"""
317+
return -10 * np.log10(np.mean(np.abs(np.dot(np.linalg.pinv(A), A0) - np.eye(np.shape(A0)[1]))))
318+
319+
320+
321+

0 commit comments

Comments
 (0)