@@ -34,9 +34,12 @@ def __init__(self, hparams):
3434 self .hparams = hparams
3535 self .model = models .__dict__ [self .hparams .arch ](pretrained = self .hparams .pretrained )
3636
37+ def forward (self , x ):
38+ return self .model (x )
39+
3740 def training_step (self , batch , batch_idx ):
3841 images , target = batch
39- output = self .model (images )
42+ output = self .forward (images )
4043 loss_val = F .cross_entropy (output , target )
4144 acc1 , acc5 = self .__accuracy (output , target , topk = (1 , 5 ))
4245
@@ -59,7 +62,7 @@ def training_step(self, batch, batch_idx):
5962
6063 def validation_step (self , batch , batch_idx ):
6164 images , target = batch
62- output = self .model (images )
65+ output = self .forward (images )
6366 loss_val = F .cross_entropy (output , target )
6467 acc1 , acc5 = self .__accuracy (output , target , topk = (1 , 5 ))
6568
@@ -132,7 +135,7 @@ def train_dataloader(self):
132135 std = [0.229 , 0.224 , 0.225 ],
133136 )
134137
135- train_dir = os .path .join (self .hparams .data , 'train' )
138+ train_dir = os .path .join (self .hparams .data_path , 'train' )
136139 train_dataset = datasets .ImageFolder (
137140 train_dir ,
138141 transforms .Compose ([
@@ -162,7 +165,7 @@ def val_dataloader(self):
162165 mean = [0.485 , 0.456 , 0.406 ],
163166 std = [0.229 , 0.224 , 0.225 ],
164167 )
165- val_dir = os .path .join (self .hparams .data , 'val' )
168+ val_dir = os .path .join (self .hparams .data_path , 'val' )
166169 val_loader = torch .utils .data .DataLoader (
167170 datasets .ImageFolder (val_dir , transforms .Compose ([
168171 transforms .Resize (256 ),
@@ -185,7 +188,7 @@ def add_model_specific_args(parent_parser): # pragma: no cover
185188 ' (default: resnet18)' )
186189 parser .add_argument ('--epochs' , default = 90 , type = int , metavar = 'N' ,
187190 help = 'number of total epochs to run' )
188- parser .add_argument ('--seed' , type = int , default = None ,
191+ parser .add_argument ('--seed' , type = int , default = 42 ,
189192 help = 'seed for initializing training. ' )
190193 parser .add_argument ('-b' , '--batch-size' , default = 256 , type = int ,
191194 metavar = 'N' ,
@@ -214,7 +217,7 @@ def get_args():
214217 help = 'how many gpus' )
215218 parent_parser .add_argument ('--distributed-backend' , type = str , default = 'dp' , choices = ('dp' , 'ddp' , 'ddp2' ),
216219 help = 'supports three options dp, ddp, ddp2' )
217- parent_parser .add_argument ('--use-16bit' , dest = 'use-16bit ' , action = 'store_true' ,
220+ parent_parser .add_argument ('--use-16bit' , dest = 'use_16bit ' , action = 'store_true' ,
218221 help = 'if true uses 16 bit precision' )
219222 parent_parser .add_argument ('-e' , '--evaluate' , dest = 'evaluate' , action = 'store_true' ,
220223 help = 'evaluate model on validation set' )
0 commit comments