Skip to content

Commit 46c4b5c

Browse files
committed
Add new models.
1 parent 800191d commit 46c4b5c

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

configs/benchmark/diffusercam.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ algorithms: [
3535
# "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",
3636
"hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN",
3737

38+
# ## comparing UNetRes and Transformer, ADAMW optimizer
39+
# "hf:diffusercam:mirflickr:Transformer4M+U5+Transformer4M",
40+
# "hf:diffusercam:mirflickr:Transformer4M+U5+Transformer4M_psfNN",
41+
# "hf:diffusercam:mirflickr:U5+Transformer8M",
42+
# "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_adamw",
43+
# "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN_adamw",
44+
3845
# # -- benchmark PSF error
3946
# "hf:diffusercam:mirflickr:U5+Unet8M_psf0dB",
4047
# "hf:diffusercam:mirflickr:U5+Unet8M_psf-5dB",
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# python scripts/eval/benchmark_recon.py -cn diffusercam_fullres
2+
defaults:
3+
- defaults
4+
- _self_
5+
6+
dataset: HFDataset
7+
batchsize: 4
8+
device: "cuda:0"
9+
10+
huggingface:
11+
repo: "bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM"
12+
psf: psf.tiff
13+
image_res: null
14+
rotate: False # if measurement is upside-down
15+
alignment: null
16+
downsample: 1
17+
downsample_lensed: 1
18+
flipud: True
19+
flip_lensed: True
20+
single_channel_psf: True
21+
22+
algorithms: [
23+
# "ADMM",
24+
25+
# ## comparing LeADMM5 and SVDeconvNet, ADAMW optimizer
26+
"hf:diffusercam:mirflickr:Unet6M+U5+Unet6M_fullres",
27+
"hf:diffusercam:mirflickr:Unet6M+U5+Unet6M_psfNN_fullres",
28+
"hf:diffusercam:mirflickr:SVDecon+UNet8M",
29+
"hf:diffusercam:mirflickr:Unet4M+SVDecon+Unet4M",
30+
]
31+
32+
save_idx: [0, 1, 3, 4, 8]
33+
n_iter_range: [100] # for ADMM
34+

configs/benchmark/multilens_ambient.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ algorithms: [
2828
# "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_direct_sub",
2929
# "hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_learned_sub",
3030
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_concat",
31+
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_concat_psfNN",
3132
# "hf:multilens:mirflickr_ambient:TrainInv+Unet8M",
3233
# "hf:multilens:mirflickr_ambient:TrainInv+Unet8M_learned_sub",
3334
# "hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M",

lensless/recon/model_dict.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from lensless.hardware.trainable_mask import prep_trainable_mask
1818
import yaml
1919
from lensless.recon.multi_wiener import MultiWiener
20+
from lensless import SVDeconvNet
2021
from huggingface_hub import snapshot_download
2122
from collections import OrderedDict
2223
from lensless.utils.dataset import MyDataParallel
@@ -89,6 +90,17 @@
8990
"Unet4M+U5+Unet4M_ft_tapecam": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-ft-tapecam",
9091
"Unet4M+U5+Unet4M_ft_tapecam_post": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-ft-tapecam-post",
9192
"Unet4M+U5+Unet4M_ft_tapecam_pre": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-ft-tapecam-pre",
93+
# comparing with transformers, with ADAMW optimizer (rest with ADAM)
94+
"Unet4M+U5+Unet4M_adamw": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-adamw",
95+
"Unet4M+U5+Unet4M_psfNN_adamw": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-adamw",
96+
"U5+Transformer8M": "bezzam/diffusercam-mirflickr-unrolled-admm5-transformer8M",
97+
"Transformer4M+U5+Transformer4M": "bezzam/diffusercam-mirflickr-transformer4M-unrolled-admm5-transformer4M",
98+
"Transformer4M+U5+Transformer4M_psfNN": "bezzam/difusercam-mirflickr-transformer4M-unrolled-admm5-transformer4M-psfNN",
99+
# (~11.6M param) comparing with SVDeconvNet, with ADAMW optimizer (full resolution images)
100+
"Unet6M+U5+Unet6M_fullres": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-fullres",
101+
"Unet6M+U5+Unet6M_psfNN_fullres": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-fullres",
102+
"SVDecon+UNet8M": "bezzam/diffusercam-mirflickr-svdecon-unet4M",
103+
"Unet4M+SVDecon+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-svdecon-unet4M",
92104
},
93105
"mirflickr_sim": {
94106
"Unet4M+U5+Unet4M": "bezzam/diffusercam-mirflickr-sim-unet4M-unrolled-admm5-unet4M",
@@ -205,6 +217,7 @@
205217
"Unet4M+U5+Unet4M_direct_sub": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-direct-sub",
206218
"Unet4M+U5+Unet4M_learned_sub": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-learned-sub",
207219
"Unet4M+U5+Unet4M_concat": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-concat-ext",
220+
"Unet4M+U5+Unet4M_concat_psfNN": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-concat-psfNN",
208221
"TrainInv+Unet8M": "lensless/multilens-mirflickr-ambient-trainable-inv-unet8M",
209222
"TrainInv+Unet8M_learned_sub": "lensless/multilens-mirflickr-ambient-trainable-inv-unet8M-learned-sub",
210223
"Unet4M+TrainInv+Unet4M": "lensless/multilens-mirflickr-ambient-unet4M-trainable-inv-unet4M",
@@ -392,27 +405,37 @@ def load_model(
392405

393406
pre_process, _ = create_process_network(
394407
network=config["reconstruction"]["pre_process"]["network"],
408+
device=device,
409+
input_background=config["reconstruction"].get("unetres_input_background", False),
410+
# unetres param
395411
depth=config["reconstruction"]["pre_process"]["depth"],
396412
nc=config["reconstruction"]["pre_process"]["nc"]
397413
if "nc" in config["reconstruction"]["pre_process"].keys()
398414
else None,
399-
device=device,
400-
input_background=config["reconstruction"].get("unetres_input_background", False),
415+
# restormer parameters
416+
restormer_params=config["reconstruction"]["pre_process"].get(
417+
"restormer_params", None
418+
),
401419
)
402420

403421
if config["reconstruction"]["post_process"]["network"] is not None:
404422

405423
post_process, _ = create_process_network(
406424
network=config["reconstruction"]["post_process"]["network"],
407-
depth=config["reconstruction"]["post_process"]["depth"],
408-
nc=config["reconstruction"]["post_process"]["nc"]
409-
if "nc" in config["reconstruction"]["post_process"].keys()
410-
else None,
411425
device=device,
412426
# get from dict
413427
concatenate_compensation=config["reconstruction"]["compensation"][-1]
414428
if config["reconstruction"].get("compensation", None) is not None
415429
else False,
430+
# unetres param
431+
depth=config["reconstruction"]["post_process"]["depth"],
432+
nc=config["reconstruction"]["post_process"]["nc"]
433+
if "nc" in config["reconstruction"]["post_process"].keys()
434+
else None,
435+
# restormer parameters
436+
restormer_params=config["reconstruction"]["post_process"].get(
437+
"restormer_params", None
438+
),
416439
)
417440

418441
if train_last_layer:
@@ -489,6 +512,25 @@ def load_model(
489512
)
490513
recon.to(device)
491514

515+
elif config["reconstruction"]["method"] == "svdeconvnet":
516+
psf_learned = psf_learned[0] # remove singleton dimension
517+
recon = SVDeconvNet(
518+
psf[0].unsqueeze(
519+
0
520+
), # set one of PSF (just for shape) but will be overwritten later
521+
K=config["reconstruction"]["svdeconvnet"]["K"],
522+
pre_process=pre_process,
523+
post_process=post_process,
524+
background_network=background_network,
525+
return_intermediate=return_intermediate,
526+
direct_background_subtraction=config["reconstruction"].get(
527+
"direct_background_subtraction", False
528+
),
529+
integrated_background_subtraction=config["reconstruction"].get(
530+
"integrated_background_subtraction", False
531+
),
532+
)
533+
492534
if mask is not None:
493535
psf_learned = torch.nn.Parameter(psf_learned)
494536
recon._set_psf(psf_learned)

0 commit comments

Comments
 (0)