Skip to content

Commit 1a54fb7

Browse files
committed
formatting
1 parent dc1f210 commit 1a54fb7

File tree

2 files changed

+59
-64
lines changed

2 files changed

+59
-64
lines changed

examples/sngan_example.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import torch_mimicry as mmc
44
from torch_mimicry.nets import sngan
55

6-
76
if __name__ == "__main__":
87
# Data handling objects
98
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
109
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
11-
dataloader = torch.utils.data.DataLoader(
12-
dataset, batch_size=64, shuffle=True, num_workers=4)
10+
dataloader = torch.utils.data.DataLoader(dataset,
11+
batch_size=64,
12+
shuffle=True,
13+
num_workers=4)
1314

1415
# Define models and optimizers
1516
netG = sngan.SNGANGenerator32().to(device)
@@ -18,46 +19,42 @@
1819
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
1920

2021
# Start training
21-
trainer = mmc.training.Trainer(
22-
netD=netD,
23-
netG=netG,
24-
optD=optD,
25-
optG=optG,
26-
n_dis=5,
27-
num_steps=100000,
28-
lr_decay='linear',
29-
dataloader=dataloader,
30-
log_dir='./log/example',
31-
device=device)
22+
trainer = mmc.training.Trainer(netD=netD,
23+
netG=netG,
24+
optD=optD,
25+
optG=optG,
26+
n_dis=5,
27+
num_steps=100000,
28+
lr_decay='linear',
29+
dataloader=dataloader,
30+
log_dir='./log/example',
31+
device=device)
3232
trainer.train()
3333

3434
# Evaluate fid
35-
mmc.metrics.evaluate(
36-
metric='fid',
37-
log_dir='./log/example',
38-
netG=netG,
39-
dataset_name='cifar10',
40-
num_real_samples=50000,
41-
num_fake_samples=50000,
42-
evaluate_step=100000,
43-
device=device)
35+
mmc.metrics.evaluate(metric='fid',
36+
log_dir='./log/example',
37+
netG=netG,
38+
dataset_name='cifar10',
39+
num_real_samples=50000,
40+
num_fake_samples=50000,
41+
evaluate_step=100000,
42+
device=device)
4443

4544
# Evaluate kid
46-
mmc.metrics.evaluate(
47-
metric='kid',
48-
log_dir='./log/example',
49-
netG=netG,
50-
dataset_name='cifar10',
51-
num_subsets=50,
52-
subset_size=1000,
53-
evaluate_step=100000,
54-
device=device)
45+
mmc.metrics.evaluate(metric='kid',
46+
log_dir='./log/example',
47+
netG=netG,
48+
dataset_name='cifar10',
49+
num_subsets=50,
50+
subset_size=1000,
51+
evaluate_step=100000,
52+
device=device)
5553

5654
# Evaluate inception score
57-
mmc.metrics.evaluate(
58-
metric='inception_score',
59-
log_dir='./log/example',
60-
netG=netG,
61-
num_samples=50000,
62-
evaluate_step=100000,
63-
device=device)
55+
mmc.metrics.evaluate(metric='inception_score',
56+
log_dir='./log/example',
57+
netG=netG,
58+
num_samples=50000,
59+
evaluate_step=100000,
60+
device=device)

examples/ssgan_tutorial.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch_mimicry.modules import SNLinear
1414
from torch_mimicry.modules import GBlock, DBlock, DBlockOptimized
1515

16+
1617
#######################
1718
# Models
1819
#######################
@@ -118,8 +119,6 @@ def __init__(self, ndf=128, loss_type='hinge', **kwargs):
118119
self.l_y = SNLinear(self.ndf, self.num_classes)
119120
nn.init.xavier_uniform_(self.l_y.weight.data, 1.0)
120121

121-
122-
123122
def forward(self, x):
124123
"""
125124
Feedforwards a batch of real/fake images and produces a batch of GAN logits,
@@ -141,7 +140,6 @@ def forward(self, x):
141140

142141
return output, output_classes
143142

144-
145143
def _rot_tensor(self, image, deg):
146144
"""
147145
Rotation for pytorch tensors using rotation matrix. Takes in a tensor of (C, H, W shape).
@@ -216,7 +214,7 @@ def train_step(self,
216214
netG,
217215
optD,
218216
log_data,
219-
device=None,
217+
device=None,
220218
global_step=None,
221219
**kwargs):
222220
"""
@@ -272,8 +270,10 @@ def train_step(self,
272270
# Data handling objects
273271
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
274272
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
275-
dataloader = torch.utils.data.DataLoader(
276-
dataset, batch_size=64, shuffle=True, num_workers=4)
273+
dataloader = torch.utils.data.DataLoader(dataset,
274+
batch_size=64,
275+
shuffle=True,
276+
num_workers=4)
277277

278278
# Define models and optimizers
279279
netG = SSGANGenerator().to(device)
@@ -282,28 +282,26 @@ def train_step(self,
282282
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
283283

284284
# Start training
285-
trainer = mmc.training.Trainer(
286-
netD=netD,
287-
netG=netG,
288-
optD=optD,
289-
optG=optG,
290-
n_dis=2,
291-
num_steps=100000,
292-
dataloader=dataloader,
293-
log_dir=log_dir,
294-
device=device)
285+
trainer = mmc.training.Trainer(netD=netD,
286+
netG=netG,
287+
optD=optD,
288+
optG=optG,
289+
n_dis=2,
290+
num_steps=100000,
291+
dataloader=dataloader,
292+
log_dir=log_dir,
293+
device=device)
295294
trainer.train()
296295

297296
##########################
298297
# Evaluation
299298
##########################
300299
# Evaluate fid
301-
mmc.metrics.evaluate(
302-
metric='fid',
303-
log_dir=log_dir,
304-
netG=netG,
305-
dataset_name='cifar10',
306-
num_real_samples=10000,
307-
num_fake_samples=10000,
308-
evaluate_step=100000,
309-
device=device)
300+
mmc.metrics.evaluate(metric='fid',
301+
log_dir=log_dir,
302+
netG=netG,
303+
dataset_name='cifar10',
304+
num_real_samples=10000,
305+
num_fake_samples=10000,
306+
evaluate_step=100000,
307+
device=device)

0 commit comments

Comments
 (0)