Commit 00dc639 (1 parent: 2b88813)

added training code and pretrained acid models


41 files changed: +77449 −41 lines

README.md: 88 additions & 2 deletions
@@ -68,7 +68,7 @@ downloaded the first time they are required. Specify an output path using
 
 ```
 > braindance.py -h
-usage: braindance.py [-h] [--model {re_impl_nodepth,re_impl_depth}] [--video [VIDEO]] [path]
+usage: braindance.py [-h] [--model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}] [--video [VIDEO]] [path]
 
 What's up, BD-maniacs?
@@ -85,11 +85,97 @@ positional arguments:
 
 optional arguments:
   -h, --help            show this help message and exit
-  --model {re_impl_nodepth,re_impl_depth}
+  --model {re_impl_nodepth,re_impl_depth,ac_impl_nodepth,ac_impl_depth}
                         pretrained model to use.
   --video [VIDEO]       path to write video recording to. (no recording if unspecified).
 ```
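
With the two ACID checkpoints added to the `--model` choices above, the demo can be pointed at them directly, e.g. using one of the new values from the usage string:

```
> braindance.py --model ac_impl_depth
```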

## Training

### Data Preparation

We support training on [RealEstate10K](https://google.github.io/realestate10k/)
and [ACID](https://infinite-nature.github.io/). Both come in the same [format as
described here](https://google.github.io/realestate10k/download.html) and the
preparation is the same for both of them. You will need to have
[`colmap`](https://github.com/colmap/colmap) installed and available on your
`$PATH`.
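
To quickly confirm the `colmap` requirement is met (a minimal sanity check, assuming a POSIX shell):

```
which colmap   # should print the path of the colmap executable
```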

We assume that you have extracted the `.txt` files of the dataset you want to
prepare into `$TXT_ROOT`, e.g. for RealEstate:

```
> tree $TXT_ROOT
├── test
│   ├── 000c3ab189999a83.txt
│   ├── ...
│   └── fff9864727c42c80.txt
└── train
    ├── 0000cc6d8b108390.txt
    ├── ...
    └── ffffe622a4de5489.txt
```

and that you have downloaded the frames (we downloaded them in resolution `640
x 360`) into `$IMG_ROOT`, e.g. for RealEstate:

```
> tree $IMG_ROOT
├── test
│   ├── 000c3ab189999a83
│   │   ├── 45979267.png
│   │   ├── ...
│   │   └── 55255200.png
│   ├── ...
│   ├── 0017ce4c6a39d122
│   │   ├── 40874000.png
│   │   ├── ...
│   │   └── 48482000.png
├── train
│   ├── ...
```

To prepare the `$SPLIT` split of the dataset (`$SPLIT` being one of `train`,
`test` for RealEstate and `train`, `test`, `validation` for ACID) in
`$SPA_ROOT`, run the following within the `scripts` directory:

```
python sparse_from_realestate_format.py --txt_src ${TXT_ROOT}/${SPLIT} --img_src ${IMG_ROOT}/${SPLIT} --spa_dst ${SPA_ROOT}/${SPLIT}
```

You can also simply set `TXT_ROOT`, `IMG_ROOT` and `SPA_ROOT` as environment
variables and run `./sparsify_realestate.sh` or `./sparsify_acid.sh`. Take a
look into the sources to run with multiple workers in parallel.

Finally, symlink `$SPA_ROOT` to `data/realestate_sparse`/`data/acid_sparse`.
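
For example, a minimal end-to-end preparation of ACID could look as follows (the paths are placeholders for your own directories):

```
export TXT_ROOT=/path/to/acid_txt      # extracted .txt files
export IMG_ROOT=/path/to/acid_frames   # downloaded frames
export SPA_ROOT=/path/to/acid_sparse   # destination for the prepared data
./sparsify_acid.sh
ln -s "$SPA_ROOT" data/acid_sparse     # make the prepared data visible to the training code
```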

### First Stage Models

As described in [our paper](https://arxiv.org/abs/2104.07652), we train the transformer models in
a compressed, discrete latent space of pretrained VQGANs. These pretrained models can be conveniently
downloaded by running

```
python scripts/download_vqmodels.py
```

which will also create symlinks ensuring that the paths specified in the training configs (see `configs/*`) exist.
In case some of the models have already been downloaded, the script will only create the symlinks.
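
As a sanity check after the download: the ACID configs added in this commit reference their first stage via `ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"`, so that path should now resolve through the created symlinks, e.g.:

```
ls -L pretrained_models/acid_first_stage/last.ckpt
```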

For training custom first stage models, we refer to the [taming transformers
repository](https://github.com/CompVis/taming-transformers).

### Running the Training

After both the preparation of the data and the first stage models are done,
the experiments on ACID and RealEstate10K as described in our paper can be reproduced by running

```
python geofree/main.py --base configs/<dataset>/<dataset>_13x23_<experiment>.yaml -t --gpus 0,
```

where `<dataset>` is one of `realestate`/`acid` and `<experiment>` is one of
`expl_img`/`expl_feat`/`expl_emb`/`impl_catdepth`/`impl_depth`/`impl_nodepth`/`hybrid`.
These abbreviations correspond to the experiments listed in the following table (see also Fig. 2 in the main paper).

![variants](assets/geofree_variants.png)
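
For example, the implicit-depth experiment on RealEstate10K corresponds to the following invocation (one instantiation of the pattern above):

```
python geofree/main.py --base configs/realestate/realestate_13x23_impl_depth.yaml -t --gpus 0,
```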

Note that each experiment was conducted on a GPU with 40 GB VRAM.
## BibTeX

assets/geofree_variants.png: 167 KB

assets/rooms_scenic_01_wkr.jpg: 136 KB
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
model:
  base_learning_rate: 0.0625
  target: geofree.models.transformers.warpgpt.WarpTransformer
  params:
    plot_cond_stage: True
    monitor: "val/loss"

    use_scheduler: True
    scheduler_config:
      target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0 # 0 or negative to disable
        warm_up_steps: 5000
        max_decay_steps: 500001
        lr_start: 2.5e-6
        lr_max: 1.5e-4
        lr_min: 1.0e-8

    transformer_config:
      target: geofree.modules.transformer.mingpt.WarpGPT
      params:
        vocab_size: 16384
        block_size: 597 # conditioning + 299 - 1
        n_unmasked: 299 # 299 cond embeddings
        n_layer: 32
        n_head: 16
        n_embd: 1024
        warper_config:
          target: geofree.modules.transformer.warper.ConvWarper
          params:
            size: [13, 23]

    first_stage_config:
      target: geofree.models.vqgan.VQModel
      params:
        ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
        embed_dim: 256
        n_embed: 16384
        ddconfig:
          double_z: False
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: geofree.modules.losses.vqperceptual.DummyLoss

    cond_stage_config: "__is_first_stage__"

data:
  target: geofree.main.DataModuleFromConfig
  params:
    # bs 8 and accumulate_grad_batches 2 for 34gb vram
    batch_size: 8
    num_workers: 16
    train:
      target: geofree.data.acid.ACIDSparseTrain
      params:
        size:
          - 208
          - 368

    validation:
      target: geofree.data.acid.ACIDCustomTest
      params:
        size:
          - 208
          - 368

lightning:
  trainer:
    accumulate_grad_batches: 2
    benchmark: True
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
model:
  base_learning_rate: 0.0625
  target: geofree.models.transformers.net2net.WarpingFeatureTransformer
  params:
    plot_cond_stage: True
    monitor: "val/loss"

    use_scheduler: True
    scheduler_config:
      target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0 # 0 or negative to disable
        warm_up_steps: 5000
        max_decay_steps: 500001
        lr_start: 2.5e-6
        lr_max: 1.5e-4
        lr_min: 1.0e-8

    transformer_config:
      target: geofree.modules.transformer.mingpt.GPT
      params:
        vocab_size: 16384
        block_size: 597 # conditioning + 299 - 1
        n_unmasked: 299 # 299 cond embeddings
        n_layer: 32
        n_head: 16
        n_embd: 1024

    first_stage_key:
      x: "dst_img"

    cond_stage_key:
      c: "src_img"
      points: "src_points"
      R: "R_rel"
      t: "t_rel"
      K: "K"
      K_inv: "K_inv"

    first_stage_config:
      target: geofree.models.vqgan.VQModel
      params:
        ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
        embed_dim: 256
        n_embed: 16384
        ddconfig:
          double_z: False
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: geofree.modules.losses.vqperceptual.DummyLoss

    cond_stage_config: "__is_first_stage__"

data:
  target: geofree.main.DataModuleFromConfig
  params:
    # bs 8 and accumulate_grad_batches 2 for 34gb vram
    batch_size: 8
    num_workers: 16
    train:
      target: geofree.data.acid.ACIDSparseTrain
      params:
        size:
          - 208
          - 368

    validation:
      target: geofree.data.acid.ACIDCustomTest
      params:
        size:
          - 208
          - 368

lightning:
  trainer:
    accumulate_grad_batches: 2
    benchmark: True
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
model:
  base_learning_rate: 0.0625
  target: geofree.models.transformers.net2net.WarpingTransformer
  params:
    plot_cond_stage: True
    monitor: "val/loss"

    use_scheduler: True
    scheduler_config:
      target: geofree.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0 # 0 or negative to disable
        warm_up_steps: 5000
        max_decay_steps: 500001
        lr_start: 2.5e-6
        lr_max: 1.5e-4
        lr_min: 1.0e-8

    transformer_config:
      target: geofree.modules.transformer.mingpt.GPT
      params:
        vocab_size: 16384
        block_size: 597 # conditioning + 299 - 1
        n_unmasked: 299 # 299 cond embeddings
        n_layer: 32
        n_head: 16
        n_embd: 1024

    first_stage_config:
      target: geofree.models.vqgan.VQModel
      params:
        ckpt_path: "pretrained_models/acid_first_stage/last.ckpt"
        embed_dim: 256
        n_embed: 16384
        ddconfig:
          double_z: False
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: geofree.modules.losses.vqperceptual.DummyLoss

data:
  target: geofree.main.DataModuleFromConfig
  params:
    # bs 8 and accumulate_grad_batches 2 for 34gb vram
    batch_size: 8
    num_workers: 16
    train:
      target: geofree.data.acid.ACIDSparseTrain
      params:
        size:
          - 208
          - 368

    validation:
      target: geofree.data.acid.ACIDCustomTest
      params:
        size:
          - 208
          - 368

lightning:
  trainer:
    accumulate_grad_batches: 2
    benchmark: True
