1313from torch_mimicry .modules import SNLinear
1414from torch_mimicry .modules import GBlock , DBlock , DBlockOptimized
1515
16+
1617#######################
1718# Models
1819#######################
@@ -118,8 +119,6 @@ def __init__(self, ndf=128, loss_type='hinge', **kwargs):
118119 self .l_y = SNLinear (self .ndf , self .num_classes )
119120 nn .init .xavier_uniform_ (self .l_y .weight .data , 1.0 )
120121
121-
122-
123122 def forward (self , x ):
124123 """
125124 Feedforwards a batch of real/fake images and produces a batch of GAN logits,
@@ -141,7 +140,6 @@ def forward(self, x):
141140
142141 return output , output_classes
143142
144-
145143 def _rot_tensor (self , image , deg ):
146144 """
147145 Rotation for pytorch tensors using rotation matrix. Takes in a tensor of (C, H, W shape).
@@ -216,7 +214,7 @@ def train_step(self,
216214 netG ,
217215 optD ,
218216 log_data ,
219- device = None ,
217+ device = None ,
220218 global_step = None ,
221219 ** kwargs ):
222220 """
@@ -272,8 +270,10 @@ def train_step(self,
272270# Data handling objects
273271device = torch .device ('cuda:0' if torch .cuda .is_available () else "cpu" )
274272dataset = mmc .datasets .load_dataset (root = './datasets' , name = 'cifar10' )
275- dataloader = torch .utils .data .DataLoader (
276- dataset , batch_size = 64 , shuffle = True , num_workers = 4 )
273+ dataloader = torch .utils .data .DataLoader (dataset ,
274+ batch_size = 64 ,
275+ shuffle = True ,
276+ num_workers = 4 )
277277
278278# Define models and optimizers
279279netG = SSGANGenerator ().to (device )
@@ -282,28 +282,26 @@ def train_step(self,
282282optG = optim .Adam (netG .parameters (), 2e-4 , betas = (0.0 , 0.9 ))
283283
284284# Start training
285- trainer = mmc .training .Trainer (
286- netD = netD ,
287- netG = netG ,
288- optD = optD ,
289- optG = optG ,
290- n_dis = 2 ,
291- num_steps = 100000 ,
292- dataloader = dataloader ,
293- log_dir = log_dir ,
294- device = device )
285+ trainer = mmc .training .Trainer (netD = netD ,
286+ netG = netG ,
287+ optD = optD ,
288+ optG = optG ,
289+ n_dis = 2 ,
290+ num_steps = 100000 ,
291+ dataloader = dataloader ,
292+ log_dir = log_dir ,
293+ device = device )
295294trainer .train ()
296295
297296##########################
298297# Evaluation
299298##########################
300299# Evaluate fid
301- mmc .metrics .evaluate (
302- metric = 'fid' ,
303- log_dir = log_dir ,
304- netG = netG ,
305- dataset_name = 'cifar10' ,
306- num_real_samples = 10000 ,
307- num_fake_samples = 10000 ,
308- evaluate_step = 100000 ,
309- device = device )
300+ mmc .metrics .evaluate (metric = 'fid' ,
301+ log_dir = log_dir ,
302+ netG = netG ,
303+ dataset_name = 'cifar10' ,
304+ num_real_samples = 10000 ,
305+ num_fake_samples = 10000 ,
306+ evaluate_step = 100000 ,
307+ device = device )
0 commit comments