Skip to content

Commit 6b8ac42

Browse files
authored
Merge pull request #276 from plyfager/ada-support
[Feature] support ada module and training
2 parents 0c5643e + 68ab951 commit 6b8ac42

15 files changed

+1615
-17
lines changed

configs/styleganv3/metafile.yml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ Models:
1111
In Collection: StyleGANv3
1212
Metadata:
1313
Training Data: FFHQ
14-
Name: stylegan3_gamma32.8
14+
Name: stylegan3_noaug
1515
Results:
1616
- Dataset: FFHQ
1717
Metrics:
@@ -24,7 +24,7 @@ Models:
2424
In Collection: StyleGANv3
2525
Metadata:
2626
Training Data: Others
27-
Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8
27+
Name: stylegan3_ada
2828
Results:
2929
- Dataset: Others
3030
Metrics:
@@ -37,7 +37,7 @@ Models:
3737
In Collection: StyleGANv3
3838
Metadata:
3939
Training Data: FFHQ
40-
Name: stylegan3_gamma2.0
40+
Name: stylegan3_t
4141
Results:
4242
- Dataset: FFHQ
4343
Metrics:
Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,99 @@
1+
_base_ = [
2+
'../_base_/models/stylegan/stylegan3_base.py',
3+
'../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
4+
]
5+
6+
synthesis_cfg = {
7+
'type': 'SynthesisNetwork',
8+
'channel_base': 65536,
9+
'channel_max': 1024,
10+
'magnitude_ema_beta': 0.999,
11+
'conv_kernel': 1,
12+
'use_radial_filters': True
13+
}
14+
r1_gamma = 3.3 # set by user
15+
d_reg_interval = 16
16+
17+
load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth' # noqa
18+
19+
# ada settings
20+
aug_kwargs = {
21+
'xflip': 1,
22+
'rotate90': 1,
23+
'xint': 1,
24+
'scale': 1,
25+
'rotate': 1,
26+
'aniso': 1,
27+
'xfrac': 1,
28+
'brightness': 1,
29+
'contrast': 1,
30+
'lumaflip': 1,
31+
'hue': 1,
32+
'saturation': 1
33+
}
34+
35+
model = dict(
36+
type='StaticUnconditionalGAN',
37+
generator=dict(
38+
out_size=1024,
39+
img_channels=3,
40+
rgb2bgr=True,
41+
synthesis_cfg=synthesis_cfg),
42+
discriminator=dict(
43+
type='ADAStyleGAN2Discriminator',
44+
in_size=1024,
45+
input_bgr2rgb=True,
46+
data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
47+
gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
48+
disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))
49+
50+
imgs_root = 'data/metfaces/images/'
51+
data = dict(
52+
samples_per_gpu=4,
53+
train=dict(dataset=dict(imgs_root=imgs_root)),
54+
val=dict(imgs_root=imgs_root))
55+
56+
ema_half_life = 10. # G_smoothing_kimg
57+
58+
ema_kimg = 10
59+
ema_nimg = ema_kimg * 1000
60+
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))
61+
62+
custom_hooks = [
63+
dict(
64+
type='VisualizeUnconditionalSamples',
65+
output_dir='training_samples',
66+
interval=5000),
67+
dict(
68+
type='ExponentialMovingAverageHook',
69+
module_keys=('generator_ema', ),
70+
interp_mode='lerp',
71+
interp_cfg=dict(momentum=ema_beta),
72+
interval=1,
73+
start_iter=0,
74+
priority='VERY_HIGH')
75+
]
76+
77+
inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
78+
metrics = dict(
79+
fid50k=dict(
80+
type='FID',
81+
num_images=50000,
82+
inception_pkl=inception_pkl,
83+
inception_args=dict(type='StyleGAN'),
84+
bgr2rgb=True))
85+
86+
evaluation = dict(
87+
type='GenerativeEvalHook',
88+
interval=dict(milestones=[100000], interval=[10000, 5000]),
89+
metrics=dict(
90+
type='FID',
91+
num_images=50000,
92+
inception_pkl=inception_pkl,
93+
inception_args=dict(type='StyleGAN'),
94+
bgr2rgb=True),
95+
sample_kwargs=dict(sample_model='ema'))
96+
97+
lr_config = None
98+
99+
total_iters = 160000
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
_base_ = [
2+
'../_base_/models/stylegan/stylegan3_base.py',
3+
'../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
4+
]
5+
6+
synthesis_cfg = {
7+
'type': 'SynthesisNetwork',
8+
'channel_base': 32768,
9+
'channel_max': 512,
10+
'magnitude_ema_beta': 0.999
11+
}
12+
r1_gamma = 6.6 # set by user
13+
d_reg_interval = 16
14+
15+
load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth' # noqa
16+
# ada settings
17+
aug_kwargs = {
18+
'xflip': 1,
19+
'rotate90': 1,
20+
'xint': 1,
21+
'scale': 1,
22+
'rotate': 1,
23+
'aniso': 1,
24+
'xfrac': 1,
25+
'brightness': 1,
26+
'contrast': 1,
27+
'lumaflip': 1,
28+
'hue': 1,
29+
'saturation': 1
30+
}
31+
32+
model = dict(
33+
type='StaticUnconditionalGAN',
34+
generator=dict(
35+
out_size=1024,
36+
img_channels=3,
37+
rgb2bgr=True,
38+
synthesis_cfg=synthesis_cfg),
39+
discriminator=dict(
40+
type='ADAStyleGAN2Discriminator',
41+
in_size=1024,
42+
input_bgr2rgb=True,
43+
data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
44+
gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
45+
disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))
46+
47+
imgs_root = 'data/metfaces/images/'
48+
data = dict(
49+
samples_per_gpu=4,
50+
train=dict(dataset=dict(imgs_root=imgs_root)),
51+
val=dict(imgs_root=imgs_root))
52+
53+
ema_half_life = 10. # G_smoothing_kimg
54+
55+
ema_kimg = 10
56+
ema_nimg = ema_kimg * 1000
57+
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))
58+
59+
custom_hooks = [
60+
dict(
61+
type='VisualizeUnconditionalSamples',
62+
output_dir='training_samples',
63+
interval=5000),
64+
dict(
65+
type='ExponentialMovingAverageHook',
66+
module_keys=('generator_ema', ),
67+
interp_mode='lerp',
68+
interp_cfg=dict(momentum=ema_beta),
69+
interval=1,
70+
start_iter=0,
71+
priority='VERY_HIGH')
72+
]
73+
74+
inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
75+
metrics = dict(
76+
fid50k=dict(
77+
type='FID',
78+
num_images=50000,
79+
inception_pkl=inception_pkl,
80+
inception_args=dict(type='StyleGAN'),
81+
bgr2rgb=True))
82+
83+
evaluation = dict(
84+
type='GenerativeEvalHook',
85+
interval=dict(milestones=[80000], interval=[10000, 5000]),
86+
metrics=dict(
87+
type='FID',
88+
num_images=50000,
89+
inception_pkl=inception_pkl,
90+
inception_args=dict(type='StyleGAN'),
91+
bgr2rgb=True),
92+
sample_kwargs=dict(sample_model='ema'))
93+
94+
lr_config = None
95+
96+
total_iters = 160000

configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,15 +56,14 @@
5656
inception_args=dict(type='StyleGAN'),
5757
bgr2rgb=True))
5858

59-
inception_path = None
6059
evaluation = dict(
6160
type='GenerativeEvalHook',
6261
interval=10000,
6362
metrics=dict(
6463
type='FID',
6564
num_images=50000,
6665
inception_pkl=inception_pkl,
67-
inception_args=dict(type='StyleGAN', inception_path=inception_path),
66+
inception_args=dict(type='StyleGAN'),
6867
bgr2rgb=True),
6968
sample_kwargs=dict(sample_model='ema'))
7069

mmgen/core/evaluation/metrics.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -953,6 +953,13 @@ def summary(self):
953953
return self._result_dict
954954

955955
def extract_features(self, images):
956+
"""Extracting image features.
957+
958+
Args:
959+
images (torch.Tensor): Images tensor.
960+
Returns:
961+
torch.Tensor: Vgg16 features of input images.
962+
"""
956963
if self.use_tero_scirpt:
957964
feature = self.vgg16(images, return_features=True)
958965
else:
@@ -1278,6 +1285,17 @@ def summary(self):
12781285
return ppl_score
12791286

12801287
def get_sampler(self, model, batch_size, sample_model):
1288+
"""Get sampler for sampling along the path.
1289+
1290+
Args:
1291+
model (nn.Module): Generative model.
1292+
batch_size (int): Sampling batch size.
1293+
sample_model (str): Which model you want to use. ['ema',
1294+
'orig']. Defaults to 'ema'.
1295+
1296+
Returns:
1297+
Object: A sampler for calculating path length regularization.
1298+
"""
12811299
if sample_model == 'ema':
12821300
generator = model.generator_ema
12831301
else:

0 commit comments

Comments (0)