# Optimal Eye Surgeon (ICML 2024)

This repository contains the source code for pruning image-generator networks at initialization to alleviate overfitting.
## Repository structure

```
📦
├─ baselines
│  ├─ baseline_pai.py
│  ├─ baseline_pat.py
│  ├─ sgld.py
│  ├─ vanilla_decoder.py
│  └─ vanilla_dip.py
├─ images
├─ configs
├─ sparse_models
│  ├─ baboon
│  ├─ barbara
│  ├─ lena
│  └─ pepper
├─ sparse_models_imp
│  ├─ baboon
│  ├─ barbara
│  ├─ lena
│  └─ pepper
├─ src
│  ├─ models
│  └─ utils
├─ dip_mask.py
├─ train_sparse.py
└─ transfer.py
```
## Installation

Install conda, then create and activate the environment and install the required packages:

```shell
conda create --name oes python==3.7.16
conda activate oes
pip install -r requirements.txt && pip install -e .
```

## Demos

Run `OES_demo_comparison.ipynb` to see how OES prevents overfitting compared to other methods (approximate runtime: 10 minutes).

Run `impvsoes_comparison.ipynb` to compare OES masks at initialization with IMP masks at convergence (approximate runtime: 7 minutes).
## Reproducing the results in the paper

### Finding a supermask with OES

The following command implements the mask optimization with the Gumbel-softmax reparameterization trick, finding a sparse network with 5% of weights remaining for a noisy pepper image:

```shell
python dip_mask.py --sparsity=0.05 --image_name="pepper"
```

Vary `--sparsity` to generate supermasks at other sparsity levels.
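The idea behind the Gumbel-softmax trick can be sketched as follows. This is a minimal, self-contained toy illustration (not the implementation in `dip_mask.py`; all names and values here are hypothetical): each weight gets a learnable logit, a soft keep-probability is sampled through a two-class Gumbel-softmax so the mask stays differentiable during training, and the final supermask keeps the top fraction of weights by logit.

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax_keep_prob(logits, tau=0.5):
    # Two-class (keep/drop) Gumbel-softmax: a differentiable, stochastic
    # relaxation of the hard Bernoulli keep/drop decision per weight.
    g_keep = -np.log(-np.log(rng.uniform(size=logits.shape)))
    g_drop = -np.log(-np.log(rng.uniform(size=logits.shape)))
    a = (logits + g_keep) / tau
    b = (-logits + g_drop) / tau
    m = np.maximum(a, b)  # stabilize the softmax
    return np.exp(a - m) / (np.exp(a - m) + np.exp(b - m))

logits = rng.normal(size=1000)               # learnable per-weight logits
soft_mask = gumbel_softmax_keep_prob(logits)  # soft mask used while training

# After training, binarize at the target sparsity (here: keep 5% of weights).
k = int(0.05 * logits.size)
hard_mask = np.zeros_like(logits)
hard_mask[np.argsort(logits)[-k:]] = 1.0
```

Lowering `tau` sharpens the samples toward hard 0/1 decisions at the cost of noisier gradients.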
### Training the sparse network

After obtaining a mask with the above procedure, run the following to train the sparse network on the image. The sparse network alleviates overfitting:

```shell
python train_sparse.py -f configs/config_train_sparse.yaml
```

## Comparing with baselines
Run the following command for dense DIP:

```shell
python baselines/vanilla_dip.py -f configs/config_vanilla_dip.yaml
```

the following command for the deep decoder:

```shell
python baselines/vanilla_decoder.py -f configs/config_vanilla_decoder.yaml
```

and the following command for SGLD:

```shell
python baselines/sgld.py -f configs/config_sgld.yaml
```

## Mask transfer

For OES mask transfer, use the following command:
```shell
python transfer.py --trans_type="pai" --transferimage_name="pepper" --image_name="lena"
```

For IMP mask transfer, use the following command:

```shell
python transfer.py --trans_type="pat" --transferimage_name="pepper" --image_name="lena"
```

## Pruning baselines

To run a pruning-at-initialization baseline:

```shell
python baselines/baseline_pai.py --image_name="pepper" --prune_type="grasp_local" --sparse=0.9
```

Choose among the following options for `--prune_type`:
`rand_global`, `rand_local`, `mag_global`, `snip`, `snip_local`, `grasp`, `grasp_local`, `synflow`, `synflow_local`
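Several of these criteria score weights at initialization and keep the highest-scoring fraction. As a toy illustration of the idea (not the code in `baselines/baseline_pai.py`; the linear model and values are hypothetical), a SNIP-style saliency scores each weight by `|w * dL/dw|` at initialization:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 100))  # toy inputs
w = rng.normal(size=100)        # weights at initialization
y = rng.normal(size=32)         # toy targets

pred = x @ w
grad = 2 * x.T @ (pred - y) / len(y)  # gradient of MSE loss w.r.t. w
saliency = np.abs(w * grad)           # SNIP-style connection sensitivity

keep = int(0.1 * w.size)              # keep the top 10% of weights
mask = np.zeros_like(w)
mask[np.argsort(saliency)[-keep:]] = 1.0
```

Magnitude pruning uses `|w|` alone as the score, and GraSP and SynFlow replace the saliency with Hessian-gradient and path-norm based scores, respectively; the `_local` variants apply the budget per layer instead of globally.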
To run the IMP baseline:

```shell
python baselines/baseline_pat.py --image_name="pepper" --prune_iters=14 --percent=0.2
```

The above command runs IMP for 14 iterations, deleting 20% of the surviving weights at each iteration, which leaves roughly 5% of the weights. (Such drastic pruning degrades performance.)
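The arithmetic behind that schedule is a quick sanity check (a toy calculation, not part of the repository):

```python
# Fraction of weights surviving iterative magnitude pruning when 20% of the
# remaining weights are removed at each of 14 iterations.
percent, prune_iters = 0.2, 14
remaining = (1 - percent) ** prune_iters
print(round(remaining, 3))  # 0.044 -> roughly 5% of weights remain
```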
## Citation

If you use this code, please consider citing our work:

```bibtex
@inproceedings{ghosh2024optimal,
  title={Optimal Eye Surgeon: Finding image priors through sparse generators at initialization},
  author={Ghosh, Avrajit and Zhang, Xitong and Sun, Kenneth K and Qu, Qing and Ravishankar, Saiprasad and Wang, Rongrong},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}
```


