|
17 | 17 | from lensless.hardware.trainable_mask import prep_trainable_mask |
18 | 18 | import yaml |
19 | 19 | from lensless.recon.multi_wiener import MultiWiener |
| 20 | +from lensless import SVDeconvNet |
20 | 21 | from huggingface_hub import snapshot_download |
21 | 22 | from collections import OrderedDict |
22 | 23 | from lensless.utils.dataset import MyDataParallel |
|
89 | 90 | "Unet4M+U5+Unet4M_ft_tapecam": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-ft-tapecam", |
90 | 91 | "Unet4M+U5+Unet4M_ft_tapecam_post": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-ft-tapecam-post", |
91 | 92 | "Unet4M+U5+Unet4M_ft_tapecam_pre": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-ft-tapecam-pre", |
| 93 | + # comparing with transformers, with ADAMW optimizer (rest with ADAM) |
| 94 | + "Unet4M+U5+Unet4M_adamw": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-adamw", |
| 95 | + "Unet4M+U5+Unet4M_psfNN_adamw": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-adamw", |
| 96 | + "U5+Transformer8M": "bezzam/diffusercam-mirflickr-unrolled-admm5-transformer8M", |
| 97 | + "Transformer4M+U5+Transformer4M": "bezzam/diffusercam-mirflickr-transformer4M-unrolled-admm5-transformer4M", |
| 98 | + "Transformer4M+U5+Transformer4M_psfNN": "bezzam/difusercam-mirflickr-transformer4M-unrolled-admm5-transformer4M-psfNN", |
| 99 | + # (~11.6M param) comparing with SVDeconvNet, with ADAMW optimizer (full resolution images) |
| 100 | + "Unet6M+U5+Unet6M_fullres": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-fullres", |
| 101 | + "Unet6M+U5+Unet6M_psfNN_fullres": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-fullres", |
| 102 | + "SVDecon+UNet8M": "bezzam/diffusercam-mirflickr-svdecon-unet4M", |
| 103 | + "Unet4M+SVDecon+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-svdecon-unet4M", |
92 | 104 | }, |
93 | 105 | "mirflickr_sim": { |
94 | 106 | "Unet4M+U5+Unet4M": "bezzam/diffusercam-mirflickr-sim-unet4M-unrolled-admm5-unet4M", |
|
205 | 217 | "Unet4M+U5+Unet4M_direct_sub": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-direct-sub", |
206 | 218 | "Unet4M+U5+Unet4M_learned_sub": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-learned-sub", |
207 | 219 | "Unet4M+U5+Unet4M_concat": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-concat-ext", |
| 220 | + "Unet4M+U5+Unet4M_concat_psfNN": "lensless/multilens-mirflickr-ambient-unet4M-unrolled-admm5-unet4M-concat-psfNN", |
208 | 221 | "TrainInv+Unet8M": "lensless/multilens-mirflickr-ambient-trainable-inv-unet8M", |
209 | 222 | "TrainInv+Unet8M_learned_sub": "lensless/multilens-mirflickr-ambient-trainable-inv-unet8M-learned-sub", |
210 | 223 | "Unet4M+TrainInv+Unet4M": "lensless/multilens-mirflickr-ambient-unet4M-trainable-inv-unet4M", |
@@ -392,27 +405,37 @@ def load_model( |
392 | 405 |
|
393 | 406 | pre_process, _ = create_process_network( |
394 | 407 | network=config["reconstruction"]["pre_process"]["network"], |
| 408 | + device=device, |
| 409 | + input_background=config["reconstruction"].get("unetres_input_background", False), |
| 410 | + # unetres param |
395 | 411 | depth=config["reconstruction"]["pre_process"]["depth"], |
396 | 412 | nc=config["reconstruction"]["pre_process"]["nc"] |
397 | 413 | if "nc" in config["reconstruction"]["pre_process"].keys() |
398 | 414 | else None, |
399 | | - device=device, |
400 | | - input_background=config["reconstruction"].get("unetres_input_background", False), |
| 415 | + # restormer parameters |
| 416 | + restormer_params=config["reconstruction"]["pre_process"].get( |
| 417 | + "restormer_params", None |
| 418 | + ), |
401 | 419 | ) |
402 | 420 |
|
403 | 421 | if config["reconstruction"]["post_process"]["network"] is not None: |
404 | 422 |
|
405 | 423 | post_process, _ = create_process_network( |
406 | 424 | network=config["reconstruction"]["post_process"]["network"], |
407 | | - depth=config["reconstruction"]["post_process"]["depth"], |
408 | | - nc=config["reconstruction"]["post_process"]["nc"] |
409 | | - if "nc" in config["reconstruction"]["post_process"].keys() |
410 | | - else None, |
411 | 425 | device=device, |
412 | 426 | # get from dict |
413 | 427 | concatenate_compensation=config["reconstruction"]["compensation"][-1] |
414 | 428 | if config["reconstruction"].get("compensation", None) is not None |
415 | 429 | else False, |
| 430 | + # unetres param |
| 431 | + depth=config["reconstruction"]["post_process"]["depth"], |
| 432 | + nc=config["reconstruction"]["post_process"]["nc"] |
| 433 | + if "nc" in config["reconstruction"]["post_process"].keys() |
| 434 | + else None, |
| 435 | + # restormer parameters |
| 436 | + restormer_params=config["reconstruction"]["post_process"].get( |
| 437 | + "restormer_params", None |
| 438 | + ), |
416 | 439 | ) |
417 | 440 |
|
418 | 441 | if train_last_layer: |
@@ -489,6 +512,25 @@ def load_model( |
489 | 512 | ) |
490 | 513 | recon.to(device) |
491 | 514 |
|
| 515 | + elif config["reconstruction"]["method"] == "svdeconvnet": |
| 516 | + psf_learned = psf_learned[0] # remove singleton dimension |
| 517 | + recon = SVDeconvNet( |
| 518 | + psf[0].unsqueeze( |
| 519 | + 0 |
| 520 | + ), # set one of PSF (just for shape) but will be overwritten later |
| 521 | + K=config["reconstruction"]["svdeconvnet"]["K"], |
| 522 | + pre_process=pre_process, |
| 523 | + post_process=post_process, |
| 524 | + background_network=background_network, |
| 525 | + return_intermediate=return_intermediate, |
| 526 | + direct_background_subtraction=config["reconstruction"].get( |
| 527 | + "direct_background_subtraction", False |
| 528 | + ), |
| 529 | + integrated_background_subtraction=config["reconstruction"].get( |
| 530 | + "integrated_background_subtraction", False |
| 531 | + ), |
| 532 | + ) |
| 533 | + |
492 | 534 | if mask is not None: |
493 | 535 | psf_learned = torch.nn.Parameter(psf_learned) |
494 | 536 | recon._set_psf(psf_learned) |
|
0 commit comments