Skip to content

Commit 7d8ea58

Browse files
authored
Merge pull request #15 from saic-mdal/celeba
Celeba
2 parents a1301cc + 5b3a6c2 commit 7d8ea58

11 files changed

+30318
-5
lines changed

README.md

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,29 @@ Docker: TODO
194194
195195
## CelebA
196196
On the host machine:
197+
# Make shure you are in lama folder
198+
cd lama
199+
export TORCH_HOME=$(pwd) && export PYTHONPATH=.
197200
198-
TODO: download & prepare
199-
TODO: trian
200-
TODO: eval
201+
# Download CelebA-HQ dataset
202+
# Download data256x256.zip from https://drive.google.com/drive/folders/11Vz0fqHS2rXDb5pprgTjpD7S2BAJhi1P
203+
204+
# unzip & split into train/test/visualization & create config for it
205+
bash fetch_data/celebahq_dataset_prepare.sh
206+
207+
# generate masks for test and viz at the end of epoch
208+
bash fetch_data/celebahq_gen_masks.sh
209+
210+
# Run training
211+
# You can change bs with data.batch_size=10
212+
python bin/train.py -cn lama-fourier-celeba location=celeba
213+
214+
# Infer model on thick/thin/medium masks in 256 and run evaluation
215+
# like this:
216+
python3 bin/predict.py \
217+
model.path=$(pwd)/experiments/<user>_<date:time>_lama-fourier-celeba_/ \
218+
indir=$(pwd)/celeba-hq-dataset/visual_test_256/random_thick_256/ \
219+
outdir=$(pwd)/inference/celeba_random_thick_256 model.checkpoint=last.ckpt
201220
202221
203222
Docker: TODO
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
run_title: ''
2+
3+
visualizer:
4+
kind: directory
5+
outdir: ./visualizer-output/celeba/train_ablv2_work_no_segmpl_csdilirpl_celeba_csdilirpl1_new/samples
6+
key_order:
7+
- image
8+
- predicted_image
9+
- discr_output_fake
10+
- discr_output_real
11+
- inpainted
12+
rescale_keys:
13+
- discr_output_fake
14+
- discr_output_real
15+
trainer:
16+
kwargs:
17+
gpus: -1
18+
accelerator: ddp
19+
max_epochs: 40
20+
gradient_clip_val: 1
21+
log_gpu_memory: None
22+
limit_train_batches: 25000
23+
val_check_interval: 2600
24+
log_every_n_steps: 250
25+
precision: 32
26+
terminate_on_nan: false
27+
check_val_every_n_epoch: 1
28+
num_sanity_val_steps: 8
29+
replace_sampler_ddp: false
30+
benchmark: true
31+
resume_from_checkpoint: /group-volume/User-Driven-Content-Generation/e.logacheva/CelebA-HQ-inpainting/experiments/e.logacheva_2021-10-12_21-37-20_train_ablv2_work_no_segmpl_csdilirpl_celeba_csdilirpl1_new/models/last.ckpt
32+
checkpoint_kwargs:
33+
verbose: true
34+
save_top_k: 5
35+
save_last: true
36+
period: 1
37+
monitor: val_ssim_fid100_f1_total_mean
38+
mode: max
39+
training_model:
40+
kind: default
41+
visualize_each_iters: 1000
42+
concat_mask: true
43+
store_discr_outputs_for_vis: true
44+
losses:
45+
l1:
46+
weight_missing: 0
47+
weight_known: 10
48+
perceptual:
49+
weight: 0
50+
adversarial:
51+
kind: r1
52+
weight: 10
53+
gp_coef: 0.001
54+
mask_as_fake_target: true
55+
allow_scale_mask: true
56+
feature_matching:
57+
weight: 100
58+
segm_pl:
59+
weight: 1
60+
imagenet_weights: true
61+
optimizers:
62+
generator:
63+
kind: adam
64+
lr: 0.001
65+
discriminator:
66+
kind: adam
67+
lr: 0.0001
68+
69+
defaults:
70+
- location: celeba
71+
- data: abl-04-256-mh-dist-celeba
72+
- evaluator: default_inpainted
73+
- generator: pix2pixhd_global_sigmoid
74+
- discriminator: pix2pixhd_nlayer
75+
- hydra: overrides
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
run_title: ''
2+
3+
visualizer:
4+
kind: directory
5+
outdir: ./visualizer-output/celeba/train_ablv2_work_no_segmpl_csirpl_celeba_csirpl03_new/samples
6+
key_order:
7+
- image
8+
- predicted_image
9+
- discr_output_fake
10+
- discr_output_real
11+
- inpainted
12+
rescale_keys:
13+
- discr_output_fake
14+
- discr_output_real
15+
trainer:
16+
kwargs:
17+
gpus: -1
18+
accelerator: ddp
19+
max_epochs: 40
20+
gradient_clip_val: 1
21+
log_gpu_memory: None
22+
limit_train_batches: 25000
23+
val_check_interval: 2600
24+
log_every_n_steps: 250
25+
precision: 32
26+
terminate_on_nan: false
27+
check_val_every_n_epoch: 1
28+
num_sanity_val_steps: 8
29+
replace_sampler_ddp: false
30+
checkpoint_kwargs:
31+
verbose: true
32+
save_top_k: 5
33+
save_last: true
34+
period: 1
35+
monitor: val_ssim_fid100_f1_total_mean
36+
mode: max
37+
training_model:
38+
kind: default
39+
visualize_each_iters: 1000
40+
concat_mask: true
41+
store_discr_outputs_for_vis: true
42+
losses:
43+
l1:
44+
weight_missing: 0
45+
weight_known: 10
46+
perceptual:
47+
weight: 0
48+
adversarial:
49+
kind: r1
50+
weight: 10
51+
gp_coef: 0.001
52+
mask_as_fake_target: true
53+
allow_scale_mask: true
54+
feature_matching:
55+
weight: 100
56+
segm_pl:
57+
weight: 0.3
58+
arch_encoder: resnet50
59+
imagenet_weights: true
60+
optimizers:
61+
generator:
62+
kind: adam
63+
lr: 0.001
64+
discriminator:
65+
kind: adam
66+
lr: 0.0001
67+
68+
defaults:
69+
- location: celeba
70+
- data: abl-04-256-mh-dist-celeba
71+
- evaluator: default_inpainted
72+
- generator: pix2pixhd_global_sigmoid
73+
- discriminator: pix2pixhd_nlayer
74+
- hydra: overrides
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
run_title: ''
2+
3+
visualizer:
4+
kind: directory
5+
outdir: ./visualizer-output/celeba/train_ablv2_work_no_segmpl_vgg_celeba_l2_vgg003_new/samples
6+
key_order:
7+
- image
8+
- predicted_image
9+
- discr_output_fake
10+
- discr_output_real
11+
- inpainted
12+
rescale_keys:
13+
- discr_output_fake
14+
- discr_output_real
15+
trainer:
16+
kwargs:
17+
gpus: -1
18+
accelerator: ddp
19+
max_epochs: 40
20+
gradient_clip_val: 1
21+
log_gpu_memory: None
22+
limit_train_batches: 25000
23+
val_check_interval: 2600
24+
log_every_n_steps: 250
25+
precision: 32
26+
terminate_on_nan: false
27+
check_val_every_n_epoch: 1
28+
num_sanity_val_steps: 8
29+
replace_sampler_ddp: false
30+
checkpoint_kwargs:
31+
verbose: true
32+
save_top_k: 5
33+
save_last: true
34+
period: 1
35+
monitor: val_ssim_fid100_f1_total_mean
36+
mode: max
37+
training_model:
38+
kind: default
39+
visualize_each_iters: 1000
40+
concat_mask: true
41+
store_discr_outputs_for_vis: true
42+
losses:
43+
l1:
44+
weight_missing: 0
45+
weight_known: 10
46+
perceptual:
47+
weight: 0.03
48+
kwargs:
49+
metric: l2
50+
adversarial:
51+
kind: r1
52+
weight: 10
53+
gp_coef: 0.001
54+
mask_as_fake_target: true
55+
allow_scale_mask: true
56+
feature_matching:
57+
weight: 100
58+
segm_pl:
59+
weight: 0
60+
optimizers:
61+
generator:
62+
kind: adam
63+
lr: 0.001
64+
discriminator:
65+
kind: adam
66+
lr: 0.0001
67+
68+
defaults:
69+
- location: celeba
70+
- data: abl-04-256-mh-dist-celeba
71+
- evaluator: default_inpainted
72+
- generator: pix2pixhd_global_sigmoid
73+
- discriminator: pix2pixhd_nlayer
74+
- hydra: overrides

configs/training/data/abl-04-256-mh-dist-celeba.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# @package _group_
22

3-
batch_size: 10
3+
batch_size: 5
44
val_batch_size: 3
55
num_workers: 3
66

configs/training/lama-regular.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ defaults:
6161
- data: abl-04-256-mh-dist
6262
- evaluator: default_inpainted
6363
- trainer: any_gpu_large_ssim_ddp_final
64-
- hydra: overrides
64+
- hydra: overrides
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# @package _group_
2+
data_root_dir: /home/user/lama/celeba-hq-dataset/
3+
out_root_dir: /home/user/lama/experiments/
4+
tb_dir: /home/user/lama/tb_logs/
5+
pretrained_models: /home/user/lama/
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
mkdir celeba-hq-dataset
2+
3+
unzip data256x256.zip -d celeba-hq-dataset/
4+
5+
# Reindex
6+
for i in `echo {00001..30000}`
7+
do
8+
mv 'celeba-hq-dataset/data256x256/'$i'.jpg' 'celeba-hq-dataset/data256x256/'$[10#$i - 1]'.jpg'
9+
done
10+
11+
12+
# Split: split train -> train & val
13+
cat fetch_data/train_shuffled.flist | shuf > celeba-hq-dataset/temp_train_shuffled.flist
14+
cat celeba-hq-dataset/temp_train_shuffled.flist | head -n 2000 > celeba-hq-dataset/val_shuffled.flist
15+
cat celeba-hq-dataset/temp_train_shuffled.flist | tail -n +2001 > celeba-hq-dataset/train_shuffled.flist
16+
cat fetch_data/val_shuffled.flist > celeba-hq-dataset/visual_test_shuffled.flist
17+
18+
mkdir celeba-hq-dataset/train_256/
19+
mkdir celeba-hq-dataset/val_source_256/
20+
mkdir celeba-hq-dataset/visual_test_source_256/
21+
22+
cat celeba-hq-dataset/train_shuffled.flist | xargs -I {} mv celeba-hq-dataset/data256x256/{} celeba-hq-dataset/train_256/
23+
cat celeba-hq-dataset/val_shuffled.flist | xargs -I {} mv celeba-hq-dataset/data256x256/{} celeba-hq-dataset/val_source_256/
24+
cat celeba-hq-dataset/visual_test_shuffled.flist | xargs -I {} mv celeba-hq-dataset/data256x256/{} celeba-hq-dataset/visual_test_source_256/
25+
26+
27+
# create location config celeba.yaml
28+
PWD=$(pwd)
29+
DATASET=${PWD}/celeba-hq-dataset
30+
CELEBA=${PWD}/configs/training/location/celeba.yaml
31+
32+
touch $CELEBA
33+
echo "# @package _group_" >> $CELEBA
34+
echo "data_root_dir: ${DATASET}/" >> $CELEBA
35+
echo "out_root_dir: ${PWD}/experiments/" >> $CELEBA
36+
echo "tb_dir: ${PWD}/tb_logs/" >> $CELEBA
37+
echo "pretrained_models: ${PWD}/" >> $CELEBA

fetch_data/celebahq_gen_masks.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
python3 bin/gen_mask_dataset.py \
2+
$(pwd)/configs/data_gen/random_thick_256.yaml \
3+
celeba-hq-dataset/val_source_256/ \
4+
celeba-hq-dataset/val_256/random_thick_256/
5+
6+
python3 bin/gen_mask_dataset.py \
7+
$(pwd)/configs/data_gen/random_thin_256.yaml \
8+
celeba-hq-dataset/val_source_256/ \
9+
celeba-hq-dataset/val_256/random_thin_256/
10+
11+
python3 bin/gen_mask_dataset.py \
12+
$(pwd)/configs/data_gen/random_medium_256.yaml \
13+
celeba-hq-dataset/val_source_256/ \
14+
celeba-hq-dataset/val_256/random_medium_256/
15+
16+
python3 bin/gen_mask_dataset.py \
17+
$(pwd)/configs/data_gen/random_thick_256.yaml \
18+
celeba-hq-dataset/visual_test_source_256/ \
19+
celeba-hq-dataset/visual_test_256/random_thick_256/
20+
21+
python3 bin/gen_mask_dataset.py \
22+
$(pwd)/configs/data_gen/random_thin_256.yaml \
23+
celeba-hq-dataset/visual_test_source_256/ \
24+
celeba-hq-dataset/visual_test_256/random_thin_256/
25+
26+
python3 bin/gen_mask_dataset.py \
27+
$(pwd)/configs/data_gen/random_medium_256.yaml \
28+
celeba-hq-dataset/visual_test_source_256/ \
29+
celeba-hq-dataset/visual_test_256/random_medium_256/

0 commit comments

Comments
 (0)