Skip to content

Commit 0db3fc7

Browse files
committed
Update demo and add ICME files.
1 parent 5c7f81c commit 0db3fc7

File tree

11 files changed

+850
-72
lines changed

11 files changed

+850
-72
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ DiffuserCam_Mirflickr_200_3011302021_11h43_seed11*
88
paper/paper.pdf
99
data/*
1010
models/*
11+
multirun/*
1112
notebooks/models/*
13+
authenticate_admm/*
14+
authenticate_learned/*
1215
*.png
1316
*.jpg
1417
*.npy

configs/benchmark/diffusercam.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
# python scripts/eval/benchmark_recon.py -cn diffusercam
2+
#
3+
# To sweep over multiple hyperparameters, e.g. of mu1 and mu2 of ADMM:
4+
# python scripts/eval/benchmark_recon.py -cn diffusercam -m algorithms=[ADMM] n_iter_range=[10,100] admm.mu1=1e-6,1e-5 admm.mu2=1e-6,1e-5 admm.mu3=3e-5 admm.tau=1e-4
5+
#
6+
# Hydra will do a Cartesian product of mu1 and mu2, i.e. it will run 4 experiments.
7+
# Output will be saved to `multirun` folder.
28
defaults:
39
- defaults
410
- _self_
@@ -92,3 +98,9 @@ n_iter_range: [10] # for ADMM
9298
# save_idx: [0, 1, 3, 4, 8, 45, 58, 63]
9399
# n_iter_range: [100] # for ADMM
94100

101+
admm:
102+
mu1: 1e-6
103+
mu2: 1e-5
104+
mu3: 4e-5
105+
tau: 0.0001
106+

configs/digicam_config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ rpi:
99
device: adafruit
1010
virtual: False
1111
save: True
12+
preview: False
1213

1314
# pattern: data/psf/adafruit_random_pattern_20230719.npy
1415
pattern: random
@@ -20,9 +21,7 @@ radius: 20 # if pattern: circ
2021
center: [0, 0]
2122

2223

23-
aperture:
24-
center: null
25-
shape: null
24+
aperture: null
2625
# aperture:
2726
# center: [59,76]
2827
# shape: [19,26]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# python scripts/security/digicam_mirflickr_psf_err.py
2+
defaults:
3+
- defaults
4+
- _self_
5+
6+
cache_dir: null
7+
metrics_fp : null
8+
# metrics_fp: /root/LenslessPiCam/learned.json
9+
hf_repo: null # by default use one in model config
10+
11+
# set model
12+
# -- for learning-based methods (comment if using ADMM)
13+
model: Unet4M+U5+Unet4M_wave_psfNN
14+
15+
# # -- for ADMM with fixed parameters
16+
# model: admm
17+
n_iter: 10
18+
19+
device: cuda:1
20+
save_idx: [1, 2, 4, 5, 9]
21+
only_save_idx: False # whether to save only the specified indices
22+
n_files: null
23+
percent_pixels_wrong: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
24+
plot_vs_percent_wrong: False # whether to plot again percent wrong or correct
25+
flip: False # whether to flip mask values (True) or reset them (False)
26+
27+
compare_aes: [128, 256] # key lengths
28+
digicam_ratio: 0.6 # approximate ratio of pixels that need to be correct
29+
bit_depth: 8
30+
n_pixel: 1404

lensless/recon/gd.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm):
6464
Object for applying projected gradient descent.
6565
"""
6666

67-
def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):
67+
def __init__(self, psf, dtype=None, proj=non_neg, lip_fact=1.8, **kwargs):
6868
"""
6969
7070
Parameters
@@ -83,6 +83,7 @@ def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):
8383

8484
assert callable(proj)
8585
self._proj = proj
86+
self._lip_fact = lip_fact
8687
super(GradientDescent, self).__init__(psf, dtype, **kwargs)
8788

8889
if self._denoiser is not None:
@@ -106,7 +107,9 @@ def reset(self):
106107
# set step size as < 2 / lipschitz
107108
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
108109
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
109-
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)
110+
self._alpha = torch.real(
111+
self._lip_fact / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values
112+
)
110113

111114
else:
112115
if self._initial_est is not None:
@@ -120,7 +123,7 @@ def reset(self):
120123
# set step size as < 2 / lipschitz
121124
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
122125
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
123-
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))
126+
self._alpha = np.real(self._lip_fact / np.max(Hadj_flat * H_flat, axis=0))
124127

125128
def _grad(self):
126129
diff = self._convolver.convolve(self._image_est) - self._data

lensless/utils/io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,6 @@ def load_image(
166166
if flip_lr:
167167
img = np.fliplr(img)
168168

169-
if verbose:
170-
print_image_info(img)
171-
172169
if bg is not None:
173170

174171
# if bg is float vector, turn into int-valued vector
@@ -204,6 +201,9 @@ def load_image(
204201
dtype = original_dtype
205202
img = img.astype(dtype)
206203

204+
if verbose:
205+
print_image_info(img)
206+
207207
return img
208208

209209

lensless/utils/plot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def plot_cross_section(
230230
ax.set_xlim([-half_width, half_width])
231231
ax.grid()
232232

233+
ax.set_title("Cross-section")
233234
if dB and plot_db_drop:
234235
cross_section -= np.max(cross_section)
235236
zero_crossings = np.where(np.diff(np.signbit(cross_section + plot_db_drop)))[0]
@@ -244,8 +245,6 @@ def plot_cross_section(
244245
width = 2 * np.abs(first_crossing)
245246
ax.axvline(x=-first_crossing, c="k", linestyle="--")
246247
ax.axvline(x=+first_crossing, c="k", linestyle="--")
247-
248-
ax.set_title("Cross-section")
249248
ax.set_xlabel(f"-{plot_db_drop}dB width = {width}")
250249

251250
else:
@@ -299,7 +298,7 @@ def plot_autocorr2d(vals, pad_mode="reflect", ax=None):
299298
return ax, autocorr
300299

301300

302-
def plot_autocorr_rgb(img, width=3, figsize=None, plot_psf=False, psf_gamma=2.2):
301+
def plot_autocorr_rgb(img, width=3, figsize=None, plot_psf=False, psf_gamma=2.2, verbose=False):
303302
"""
304303
Plot autocorrelation of each channel of an image.
305304
@@ -340,13 +339,15 @@ def plot_autocorr_rgb(img, width=3, figsize=None, plot_psf=False, psf_gamma=2.2)
340339
idx = max_idx[0]
341340
# ax_auto[1][i].axhline(y=idx, c=c, linestyle="--")
342341

343-
ax, _ = plot_cross_section(
342+
ax, cross_section = plot_cross_section(
344343
autocorr_c,
345344
idx=idx,
346345
color=c,
347346
ax=ax_auto[2 if plot_psf else 1][i],
348347
plot_db_drop=width,
349348
)
349+
if verbose:
350+
print(f"Maximum drop in {c} channel: {cross_section.max() - cross_section.min()}")
350351
if i != 0:
351352
ax.set_ylabel("")
352353
return ax

notebook/lenslesspicam_demo.ipynb

Lines changed: 249 additions & 58 deletions
Large diffs are not rendered by default.

scripts/hardware/config_digicam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def config_digicam(config):
103103
with warnings.catch_warnings():
104104
warnings.simplefilter("ignore")
105105
s = slm.create(device)
106-
s._show_preview(pattern)
106+
if config.preview:
107+
s._show_preview(pattern)
107108
plt.savefig("preview.png")
108109

109110

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
This script is related to a patent that has been filed.
3+
4+
Please contact the EPFL Technology Transfer Office (https://tto.epfl.ch/, info.tto@epfl.ch) for licensing inquiries.
5+
6+
----
7+
8+
These script computes ROC curves for lensless authentication.
9+
10+
For this script, install:
11+
```
12+
pip install scikit-learn seaborn
13+
```
14+
ROC curve docs: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
15+
16+
"""
17+
18+
import numpy as np
19+
from sklearn import metrics
20+
import matplotlib.pyplot as plt
21+
import json
22+
import pandas as pd
23+
import seaborn as sn
24+
25+
font_scale = 2.3
26+
plt.rcParams.update({"font.size": 30})
27+
lw = 5 # linewidth
28+
linestyles = ["--", "-.", ":"]
29+
30+
# scores_paths = {
31+
# "ADMM10": "/root/LenslessPiCam/outputs/2024-03-25/23-36-06/scores_10_grayscaleTrue_down1_nfiles10000.json",
32+
# "ADMM25": "/root/LenslessPiCam/outputs/2024-03-26/17-52-49/scores_25_grayscaleTrue_down1_nfiles10000.json",
33+
# "ADMM50": "/root/LenslessPiCam/outputs/2024-03-27/10-49-08/scores_50_grayscaleTrue_down1_nfiles10000.json",
34+
# }
35+
36+
# scores_paths = {
37+
# "Data fid.": {
38+
# # "path": "/root/LenslessPiCam/outputs/2024-12-07/20-26-06/scores_Unet4M+U5+Unet4M_wave_psfNN_down1_nfiles3750_metricrecon.txt",
39+
# "path": "/root/LenslessPiCam/authenticate_learned/data_fid/scores_Unet4M+U5+Unet4M_wave_psfNN_down1_nfiles3750_metricrecon.txt",
40+
# "invert": True, # if lower score is True
41+
# },
42+
# # "MSE": {
43+
# # # "path": "/root/LenslessPiCam/outputs/2024-12-07/22-12-52/scores_Unet4M+U5+Unet4M_wave_psfNN_down1_nfiles3750_metricmse.txt",
44+
# # "path": "/root/LenslessPiCam/authenticate_learned/mse/scores_Unet4M+U5+Unet4M_wave_psfNN_down1_nfiles3750_metricmse.txt",
45+
# # "invert": True, # if lower score is True
46+
# # },
47+
# "LPIPS": {
48+
# # "path": "/root/LenslessPiCam/outputs/2024-12-07/18-23-12/scores_Unet4M+U5+Unet4M_wave_psfNN_down1_nfiles3750_metriclpips.txt",
49+
# "path": "/root/LenslessPiCam/authenticate_learned/lpips/scores_Unet4M+U5+Unet4M_wave_psfNN_down1_nfiles3750_metriclpips.txt",
50+
# "invert": True, # if lower score is True
51+
# },
52+
# }
53+
scores_paths = {
54+
"Data fid.": {
55+
# "path": "/root/LenslessPiCam/outputs/2024-12-08/07-17-49/scores_admm100_down1_nfiles3750_metricrecon.txt",
56+
"path": "/root/LenslessPiCam/authenticate_admm/recon/scores_admm100_down1_nfiles3750_metricrecon.txt",
57+
"invert": True, # if lower score is True
58+
},
59+
# "MSE": {
60+
# "path": "/root/LenslessPiCam/outputs/2024-12-08/19-53-17/scores_admm100_down1_nfiles3750_metricmse.txt",
61+
# "invert": True, # if lower score is True
62+
# },
63+
"LPIPS": {
64+
# "path": "/root/LenslessPiCam/outputs/2024-12-07/18-26-43/scores_admm100_down1_nfiles3750_metriclpips.txt",
65+
"path": "/root/LenslessPiCam/authenticate_admm/lpips/scores_admm100_down1_nfiles3750_metriclpips.txt",
66+
"invert": True, # if lower score is True
67+
},
68+
}
69+
70+
print_incorrect = False
71+
72+
# TODO way to get this without loading dataset?
73+
n_files_per_mask = 250
74+
mask_labels = list(np.arange(15)) * n_files_per_mask
75+
mask_labels = np.array(mask_labels)
76+
77+
# initialize figure
78+
fig, ax = plt.subplots()
79+
for method, scores_dict in scores_paths.items():
80+
print(f"--- Processing {method}...")
81+
scores_fp = scores_dict["path"]
82+
invert = scores_dict["invert"]
83+
84+
scores = []
85+
with open(scores_fp, "r") as f:
86+
for line in f:
87+
scores.append(json.loads(line))
88+
scores = np.array(scores)
89+
n_psf = len(scores)
90+
n_files = len(scores[0])
91+
92+
# compute and plot confusion matrix
93+
confusion_matrix = np.zeros((n_psf, n_psf))
94+
accuracy = np.zeros(n_psf)
95+
incorrect = dict()
96+
n_incorrect = 0
97+
y_true = [] # for ROC curve
98+
y_score = [] # for ROC curve
99+
for psf_idx in range(n_psf):
100+
101+
source_psf_mask = mask_labels == psf_idx
102+
confusion_matrix[psf_idx] = np.mean(np.array(scores[:, source_psf_mask]), axis=1)
103+
104+
# for ROC curve
105+
y_true += list(source_psf_mask)
106+
y_score += list(scores[psf_idx])
107+
108+
# compute accuracy for each PSF
109+
detected_mask = np.argmin(scores[:, source_psf_mask], axis=0)
110+
if print_incorrect:
111+
print(f"PSF {psf_idx} detected as: ", detected_mask)
112+
accuracy[int(psf_idx)] = np.mean(detected_mask == int(psf_idx))
113+
if accuracy[int(psf_idx)] < 1:
114+
incorrect_idx = np.where(detected_mask != int(psf_idx))[0]
115+
116+
# reconvert idx back to original idx
117+
incorrect_idx = np.array([np.where(source_psf_mask)[0][i] for i in incorrect_idx])
118+
incorrect[int(psf_idx)] = [int(i) for i in incorrect_idx]
119+
n_incorrect += len(incorrect_idx)
120+
121+
total_accuracy = np.mean(accuracy)
122+
print("Total accuracy: ", total_accuracy)
123+
print("Number of incorrect detections: ", n_incorrect)
124+
125+
#### FOR OLD ADMM SCORES
126+
# # load scores
127+
# with open(scores_fp, "r") as f:
128+
# scores = json.load(f)
129+
#
130+
# # prepare scores
131+
# y_true = []
132+
# y_score = []
133+
# n_psf = len(scores)
134+
# accuracy = np.zeros(n_psf)
135+
# confusion_matrix = np.zeros((n_psf, n_psf))
136+
# for psf_idx in scores:
137+
# y_true_idx = np.ones(n_psf)
138+
# y_true_idx[int(psf_idx)] = 0
139+
# for score in scores[psf_idx]:
140+
# y_true += list(y_true_idx)
141+
# y_score += list(score)
142+
143+
# # confusion matrix
144+
# confusion_matrix[int(psf_idx)] = np.mean(np.array(scores[psf_idx]), axis=0)
145+
146+
# # compute accuracy for each PSF
147+
# detected_mask = np.argmin(scores[psf_idx], axis=1)
148+
# accuracy[int(psf_idx)] = np.mean(detected_mask == int(psf_idx))
149+
150+
# total_accuracy = np.mean(accuracy)
151+
# print(f"Total accuracy ({method}): {total_accuracy:.2f}")
152+
153+
# compute and plot confusion matrix
154+
df_cm = pd.DataFrame(
155+
confusion_matrix, index=[i for i in range(n_psf)], columns=[i for i in range(n_psf)]
156+
)
157+
plt.figure(figsize=(10, 7))
158+
# set font scale
159+
sn.set(font_scale=font_scale)
160+
sn.heatmap(df_cm, annot=False, cbar=True, xticklabels=5, yticklabels=5)
161+
confusion_fn = f"confusion_matrix_{method}.png"
162+
plt.savefig(confusion_fn, bbox_inches="tight")
163+
print(f"Confusion matrix saved as {confusion_fn}")
164+
165+
# compute the ROC curve
166+
y_true = np.array(y_true).astype(bool)
167+
y_score = np.array(y_score)
168+
if invert:
169+
y_score = -1 * y_score
170+
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
171+
auc = metrics.roc_auc_score(y_true, y_score)
172+
173+
# create ROC curve
174+
ax.plot(fpr, tpr, label=f"{method}, AUC={auc:.2f}", linewidth=lw, linestyle=linestyles.pop())
175+
176+
177+
# set axis font size
178+
ax.set_ylabel("True Positive Rate")
179+
ax.set_xlabel("False Positive Rate")
180+
ax.legend()
181+
ax.grid()
182+
183+
184+
# save ROC curve
185+
plt.tight_layout()
186+
fig.savefig("roc_curve.png", bbox_inches="tight")

0 commit comments

Comments
 (0)