Skip to content

Commit c5d4b87

Browse files
committed
2 parents 3513cb4 + 432a0bc commit c5d4b87

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

pl_examples/full_examples/imagenet/imagenet_example.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ def __init__(self, hparams):
3434
self.hparams = hparams
3535
self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)
3636

37+
def forward(self, x):
38+
return self.model(x)
39+
3740
def training_step(self, batch, batch_idx):
3841
images, target = batch
39-
output = self.model(images)
42+
output = self.forward(images)
4043
loss_val = F.cross_entropy(output, target)
4144
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
4245

@@ -59,7 +62,7 @@ def training_step(self, batch, batch_idx):
5962

6063
def validation_step(self, batch, batch_idx):
6164
images, target = batch
62-
output = self.model(images)
65+
output = self.forward(images)
6366
loss_val = F.cross_entropy(output, target)
6467
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
6568

@@ -132,7 +135,7 @@ def train_dataloader(self):
132135
std=[0.229, 0.224, 0.225],
133136
)
134137

135-
train_dir = os.path.join(self.hparams.data, 'train')
138+
train_dir = os.path.join(self.hparams.data_path, 'train')
136139
train_dataset = datasets.ImageFolder(
137140
train_dir,
138141
transforms.Compose([
@@ -162,7 +165,7 @@ def val_dataloader(self):
162165
mean=[0.485, 0.456, 0.406],
163166
std=[0.229, 0.224, 0.225],
164167
)
165-
val_dir = os.path.join(self.hparams.data, 'val')
168+
val_dir = os.path.join(self.hparams.data_path, 'val')
166169
val_loader = torch.utils.data.DataLoader(
167170
datasets.ImageFolder(val_dir, transforms.Compose([
168171
transforms.Resize(256),
@@ -185,7 +188,7 @@ def add_model_specific_args(parent_parser): # pragma: no cover
185188
' (default: resnet18)')
186189
parser.add_argument('--epochs', default=90, type=int, metavar='N',
187190
help='number of total epochs to run')
188-
parser.add_argument('--seed', type=int, default=None,
191+
parser.add_argument('--seed', type=int, default=42,
189192
help='seed for initializing training. ')
190193
parser.add_argument('-b', '--batch-size', default=256, type=int,
191194
metavar='N',
@@ -214,7 +217,7 @@ def get_args():
214217
help='how many gpus')
215218
parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
216219
help='supports three options dp, ddp, ddp2')
217-
parent_parser.add_argument('--use-16bit', dest='use-16bit', action='store_true',
220+
parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
218221
help='if true uses 16 bit precision')
219222
parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
220223
help='evaluate model on validation set')

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def train(self):
353353
stop = should_stop and met_min_epochs
354354
if stop:
355355
self.main_progress_bar.close()
356+
model.on_train_end()
356357
return
357358

358359
self.main_progress_bar.close()

0 commit comments

Comments
 (0)