Skip to content

Commit eeb48ce

Browse files
HarshSharma12neggertwilliamFalcon
committed
implement forward and update args (#709) (#724)
* implement forward and update args (#709) Fixes the following issues as discussed in issue #709 1) Implement forward method wrapped. 2) Set default value for seed. "None" breaks tensorboard. 3) Update redundant hparams.data to new hparams.data_path. 4) Update 'use-16bit' to 'use_16bit' to maintain consistency. * Fix failing GPU tests (#722) * Fix distributed_backend=None test We now throw a warning instead of an exception. Update test to reflect this. * Fix test_tube logger close when debug=True * Clean docs (#725) * updated gitignore * updated gitignore * updated links in ninja file * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * updated gitignore * updated links in ninja file * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * finished rebase * making private members * making private members * making private members * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * set auto dp if no backend * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * fixed lightning import * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * finished lightning module * finished lightning module * finished lightning module * finished lightning module * added callbacks * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * set auto dp if no backend * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * flake 8 * flake 8 * fix docs path * updated gitignore * updated gitignore * updated links in ninja file * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * updated gitignore * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * finished rebase * making private members * making private members * making private members * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * set auto dp if no backend * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * fixed lightning import * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * finished lightning module * finished lightning module * finished lightning module * finished lightning module * added callbacks * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * flake 8 * flake 8 * fix docs path * flake 8 * Update theme_variables.jinja * implement forward and update args (#709) Fixes the following issues as discussed in issue #709 1) Implement forward method wrapped. 2) Set default value for seed. "None" breaks tensorboard. 3) Update redundant hparams.data to new hparams.data_path. 4) Update 'use-16bit' to 'use_16bit' to maintain consistency. * use self.forward for val step (#709) Co-authored-by: Nic Eggert <[email protected]> Co-authored-by: William Falcon <[email protected]>
1 parent f8d9f8f commit eeb48ce

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pl_examples/full_examples/imagenet/imagenet_example.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ def __init__(self, hparams):
3434
self.hparams = hparams
3535
self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)
3636

37+
def forward(self, x):
38+
return self.model(x)
39+
3740
def training_step(self, batch, batch_idx):
3841
images, target = batch
39-
output = self.model(images)
42+
output = self.forward(images)
4043
loss_val = F.cross_entropy(output, target)
4144
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
4245

@@ -59,7 +62,7 @@ def training_step(self, batch, batch_idx):
5962

6063
def validation_step(self, batch, batch_idx):
6164
images, target = batch
62-
output = self.model(images)
65+
output = self.forward(images)
6366
loss_val = F.cross_entropy(output, target)
6467
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
6568

@@ -132,7 +135,7 @@ def train_dataloader(self):
132135
std=[0.229, 0.224, 0.225],
133136
)
134137

135-
train_dir = os.path.join(self.hparams.data, 'train')
138+
train_dir = os.path.join(self.hparams.data_path, 'train')
136139
train_dataset = datasets.ImageFolder(
137140
train_dir,
138141
transforms.Compose([
@@ -162,7 +165,7 @@ def val_dataloader(self):
162165
mean=[0.485, 0.456, 0.406],
163166
std=[0.229, 0.224, 0.225],
164167
)
165-
val_dir = os.path.join(self.hparams.data, 'val')
168+
val_dir = os.path.join(self.hparams.data_path, 'val')
166169
val_loader = torch.utils.data.DataLoader(
167170
datasets.ImageFolder(val_dir, transforms.Compose([
168171
transforms.Resize(256),
@@ -185,7 +188,7 @@ def add_model_specific_args(parent_parser): # pragma: no cover
185188
' (default: resnet18)')
186189
parser.add_argument('--epochs', default=90, type=int, metavar='N',
187190
help='number of total epochs to run')
188-
parser.add_argument('--seed', type=int, default=None,
191+
parser.add_argument('--seed', type=int, default=42,
189192
help='seed for initializing training. ')
190193
parser.add_argument('-b', '--batch-size', default=256, type=int,
191194
metavar='N',
@@ -214,7 +217,7 @@ def get_args():
214217
help='how many gpus')
215218
parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
216219
help='supports three options dp, ddp, ddp2')
217-
parent_parser.add_argument('--use-16bit', dest='use-16bit', action='store_true',
220+
parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
218221
help='if true uses 16 bit precision')
219222
parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
220223
help='evaluate model on validation set')

0 commit comments

Comments
 (0)