Commit 5103736

[Hackathon 8th No.23] Reproduction of "Improved Training of Wasserstein GANs" (#1146)
* Create wgangp.yaml
* Add files via upload (×3)
* Delete examples/wgangp/conf/wgangp.yaml
* Update wgangp_cifar10.yaml
* Update functions.py
* Update wgangp_cifar10.py
* Update model.py
* Update wgangp_cifar10.py
* Update wgangp_mnist.py
* Update wgangp_toy.py
* Update functions.py
* Update model.py
* Update wgangp_cifar10.py
* Update wgangp_mnist.py
* Update wgangp_toy.py
* Update wgangp_cifar10.yaml
* Update model.py
* Update wgan_gp.md
* Update wgangp_cifar10.yaml
* Update model.py
* Update wgan_gp.md (×3)
* Add files via upload
* Update wgan_gp.md
* Add files via upload (×7)
* Update wgangp_cifar10.yaml (×2)
* Add files via upload (×4)
* Delete examples/wgangp directory
* Create wgangp_cifar10.yaml
* Add files via upload (×11)
* Delete examples/wgangp_toy_model.py
* Add files via upload (×4)
* Delete docs/zh/examples/wgan_gp.md
* Delete docs/index.md
* Delete mkdocs.yml
* Add files via upload (×3)
1 parent 4b994d8 commit 5103736

10 files changed: 2151 additions & 0 deletions
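Of the 10 files, the three Hydra configs added under examples/wgangp/conf (CIFAR-10, MNIST, and a 2D toy problem) are shown below; the remaining files are the Python scripts and docs listed in the commit message. As orientation, here is a minimal sketch of how such a config is typically consumed, assuming the usual PaddleScience pattern of a @hydra.main-decorated entry point that dispatches on the mode key; the train/evaluate stubs are hypothetical stand-ins, not the repo's actual functions:

import hydra
from omegaconf import DictConfig

def train(cfg: DictConfig):      # hypothetical stand-in for the real trainer
    print("training with seed", cfg.seed)

def evaluate(cfg: DictConfig):   # hypothetical stand-in for the real evaluator
    print("evaluating with seed", cfg.seed)

@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    # cfg.mode comes from the "mode: eval" key below and can be overridden
    # on the command line, e.g. `python wgangp_cifar10.py mode=train`
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)

if __name__ == "__main__":
    main()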
examples/wgangp/conf/wgangp_cifar10.yaml

Lines changed: 115 additions & 0 deletions
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: outputs_wgangp_cifar10/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: eval # running mode: train/eval
output_dir: ${hydra:run.dir}
seed: 42

# model settings
MODEL:
  gen_net:
    input_keys: ["labels"]
    output_keys: ["fake_data"]
    dim: 128
    output_dim: 3072
    label_num: 10
    use_label: true
  dis_net:
    input_keys: ["data", "labels"]
    output_keys: ["disc_fake", "disc_acgan"]
    dim: 128
    label_num: 10
    use_label: true

# logger settings
LOGGER:
  name: wgangp_cifar10
  level: INFO
  log_file: wgangp_cifar10.log

DATA:
  input_keys: ["labels"]
  label_keys: ["real_data"]
  data_path: ./data/cifar-10-python.tar.gz

# visualization settings
VIS:
  vis: true
  batch: 16
  num: 64

LOSS:
  gen:
    acgan_scale_g: 0.1
  dis:
    acgan_scale: 1

# training settings
TRAIN:
  dataset:
    name: "NamedArrayDataset"
  sampler:
    name: "BatchSampler"
    shuffle: true
    drop_last: true
  optimizer:
    learning_rate: 2e-4
    beta1: 0.
    beta2: 0.9
  lr_scheduler_gen:
    epochs: 100000
    iters_per_epoch: 1
    learning_rate: 2e-4
    end_lr: 0.0
    by_epoch: true
  lr_scheduler_dis:
    epochs: 100000
    iters_per_epoch: 5
    learning_rate: 2e-4
    end_lr: 0.0
    by_epoch: true
  batch_size: 64
  use_shared_memory: true
  num_workers: 0
  epochs: 100000
  epochs_dis: 1
  iters_per_epoch_dis: 5
  epochs_gen: 1
  iters_per_epoch_gen: 1
  drop_last: true
  pretrained_gen_model_path: null
  pretrained_dis_model_path: null

# evaluation settings
EVAL:
  dataset:
    name: "NamedArrayDataset"
  inceptionscore:
    eps: 0
    splits: 10
    batch_size: 64
  batch_size: 64
  use_shared_memory: true
  num_workers: 0
  pretrained_gen_model_path: null
  pretrained_dis_model_path: null
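The TRAIN block above encodes the WGAN-GP update schedule: iters_per_epoch_dis: 5 critic steps for every iters_per_epoch_gen: 1 generator step (the paper's n_critic = 5), with Adam at learning_rate: 2e-4 and betas (0, 0.9). Below is a minimal runnable sketch of that schedule in Paddle; the two Linear layers are toy stand-ins for the configured gen_net/dis_net, and the gradient-penalty and ACGAN terms from LOSS are omitted (see the penalty sketch after the MNIST config):

import paddle

gen_net = paddle.nn.Linear(16, 2)   # toy stand-in for the configured gen_net
dis_net = paddle.nn.Linear(2, 1)    # toy stand-in for the configured dis_net
opt_g = paddle.optimizer.Adam(learning_rate=2e-4, beta1=0.0, beta2=0.9,
                              parameters=gen_net.parameters())
opt_d = paddle.optimizer.Adam(learning_rate=2e-4, beta1=0.0, beta2=0.9,
                              parameters=dis_net.parameters())

for epoch in range(3):                    # TRAIN.epochs (100000 in the config)
    for _ in range(5):                    # TRAIN.iters_per_epoch_dis
        real = paddle.randn([64, 2])      # TRAIN.batch_size: 64
        # detach so critic updates do not touch generator gradients
        fake = gen_net(paddle.randn([64, 16])).detach()
        # Wasserstein critic loss (penalty and ACGAN terms omitted here)
        d_loss = paddle.mean(dis_net(fake)) - paddle.mean(dis_net(real))
        opt_d.clear_grad()
        d_loss.backward()
        opt_d.step()
    for _ in range(1):                    # TRAIN.iters_per_epoch_gen
        fake = gen_net(paddle.randn([64, 16]))
        g_loss = -paddle.mean(dis_net(fake))
        opt_g.clear_grad()
        g_loss.backward()
        opt_g.step()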
examples/wgangp/conf/wgangp_mnist.yaml

Lines changed: 90 additions & 0 deletions
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: outputs_wgangp_mnist/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: eval # running mode: train/eval
output_dir: ${hydra:run.dir}
seed: 42

# model settings
MODEL:
  gen_net:
    output_keys: ["fake_data"]
    dim: 64
    output_dim: 784
  dis_net:
    input_keys: ["data"]
    output_keys: ["score"]
    dim: 64

# logger settings
LOGGER:
  name: wgangp_mnist
  level: INFO
  log_file: wgangp_mnist.log

DATA:
  input_keys: ["real_data"]
  data_path: ./data/mnist.pkl.gz

# visualization settings
VIS:
  vis: true
  batch: 10

LOSS:
  dis:
    lamda: 10

# training settings
TRAIN:
  dataset:
    name: "NamedArrayDataset"
  sampler:
    name: "BatchSampler"
    shuffle: true
    drop_last: true
  optimizer:
    learning_rate: 1e-4
    beta1: 0.5
    beta2: 0.9
  batch_size: 100
  use_shared_memory: true
  num_workers: 0
  epochs: 200000
  epochs_dis: 1
  iters_per_epoch_dis: 5
  epochs_gen: 1
  iters_per_epoch_gen: 1
  drop_last: true
  pretrained_gen_model_path: null
  pretrained_dis_model_path: null

# evaluation settings
EVAL:
  dataset:
    name: "NamedArrayDataset"
  batch_size: 1
  use_shared_memory: true
  num_workers: 0
  pretrained_gen_model_path: null
  pretrained_dis_model_path: null
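LOSS.dis.lamda: 10 above is the gradient-penalty coefficient λ from "Improved Training of Wasserstein GANs". A minimal self-contained sketch of that penalty in Paddle follows; the one-layer critic is an illustrative stand-in, not the repo's dis_net:

import paddle

def gradient_penalty(critic, real, fake, lamda=10.0):
    # interpolate uniformly between real and fake samples
    alpha = paddle.rand([real.shape[0], 1])
    interp = alpha * real + (1.0 - alpha) * fake
    interp.stop_gradient = False
    score = critic(interp)
    # gradient of the critic score w.r.t. the interpolated input
    grad = paddle.grad(outputs=score, inputs=interp, create_graph=True)[0]
    grad_norm = paddle.sqrt(paddle.sum(grad * grad, axis=1) + 1e-12)
    # two-sided penalty: push the gradient norm toward 1
    return lamda * paddle.mean((grad_norm - 1.0) ** 2)

critic = paddle.nn.Linear(784, 1)  # toy stand-in; 784 = MODEL.gen_net.output_dim
real = paddle.randn([4, 784])
fake = paddle.randn([4, 784])
print(gradient_penalty(critic, real, fake, lamda=10.0))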

examples/wgangp/conf/wgangp_toy.yaml

Lines changed: 88 additions & 0 deletions
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: outputs_wgangp_toy/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: eval # running mode: train/eval
output_dir: ${hydra:run.dir}
seed: 42

# model settings
MODEL:
  gen_net:
    output_keys: ["fake_data"]
    dim: 512
  dis_net:
    input_keys: ["data"]
    output_keys: ["score"]
    dim: 512

# logger settings
LOGGER:
  name: wgangp_toy
  level: INFO
  log_file: wgangp_toy.log

DATA:
  input_keys: ["real_data"]
  mode: 8gaussians # swissroll/8gaussians/25gaussians

# visualization settings
VIS:
  vis: true

LOSS:
  dis:
    lamda: 0.1

# training settings
TRAIN:
  dataset:
    name: "NamedArrayDataset"
  sampler:
    name: "BatchSampler"
    shuffle: true
    drop_last: true
  optimizer:
    learning_rate: 1e-4
    beta1: 0.5
    beta2: 0.9
  batch_size: 8192
  use_shared_memory: true
  num_workers: 0
  epochs: 3125
  epochs_dis: 1
  iters_per_epoch_dis: 5
  epochs_gen: 1
  iters_per_epoch_gen: 1
  drop_last: true
  pretrained_gen_model_path: null
  pretrained_dis_model_path: null

# evaluation settings
EVAL:
  dataset:
    name: "NamedArrayDataset"
  batch_size: 8192
  use_shared_memory: true
  num_workers: 0
  pretrained_gen_model_path: null
  pretrained_dis_model_path: null
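DATA.mode selects the 2D toy distribution the script trains on (swissroll / 8gaussians / 25gaussians, per the comment). Below is a sketch of an 8gaussians sampler following the construction in the WGAN-GP paper's toy experiments, eight Gaussian modes evenly spaced on a circle; the exact constants (radius 2, per-mode std 0.02, the 1/1.414 rescale) are assumptions carried over from the reference code, not verified against this repo:

import numpy as np

def sample_8gaussians(batch_size, radius=2.0, std=0.02):
    # pick one of 8 evenly spaced angles per sample
    angles = np.random.randint(0, 8, size=batch_size) * (2 * np.pi / 8)
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    # add isotropic Gaussian noise around the chosen mode
    points = centers + np.random.randn(batch_size, 2) * std
    return (points / 1.414).astype("float32")  # rescale as in the paper code

real_data = sample_8gaussians(8192)  # TRAIN.batch_size in wgangp_toy.yaml
print(real_data.shape)  # (8192, 2)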
