trainer.tune causes "No train_dataloader() method defined." error
#7092
-
Am I right in calling the tune method as follows?
Here is the stacktrace of the error I am getting.
FWIW, here is my model definition:

from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class PretrainedResnet50FT(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--num_classes', type=int, default=2)
        parser.add_argument('--lr', type=float, default=1e-3)
        return parser

    def __init__(self, hparams):
        super().__init__()
        # Store the argparse Namespace; it is then available as self.hparams
        self.save_hyperparameters(hparams)
        # ImageNet-pretrained ResNet-50 without its final fully connected layer
        image_modules = list(models.resnet50(pretrained=True, progress=False).children())[:-1]
        self.resnet = nn.Sequential(*image_modules)
        # New classification head on top of the 2048-d pooled features
        self.classifier = nn.Linear(2048, self.hparams.num_classes)

    def forward(self, x):
        out = self.resnet(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

    def step(self, who, batch, batch_nb):
        # Each batch is (images, labels, slide ids); slide ids are not used here
        x, task_labels, slide_id = batch
        # Logits for the classification task
        task_logits = self(x)
        # Mean cross-entropy loss over the batch
        loss = F.cross_entropy(task_logits, task_labels, reduction="mean")
        # Accuracy
        task_preds = task_logits.argmax(-1)
        task_acc = pl.metrics.functional.accuracy(task_preds, task_labels)
        # F1 score
        task_f1 = pl.metrics.functional.f1(task_preds, task_labels, num_classes=self.hparams.num_classes)
        # Logged as e.g. train_loss, val_acc, test_f1
        self.log(who + '_loss', loss)
        self.log(who + '_acc', task_acc)
        self.log(who + '_f1', task_f1)
        return loss

    def training_step(self, batch, batch_nb):
        # REQUIRED
        loss = self.step('train', batch, batch_nb)
        return loss

    def validation_step(self, batch, batch_nb):
        loss = self.step('val', batch, batch_nb)
        return loss

    def test_step(self, batch, batch_nb):
        loss = self.step('test', batch, batch_nb)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
Replies: 4 comments 6 replies
-
Hi, I'm on 1.4.1 and have this issue. My code is similar. Am I doing something wrong?
-
I'm still having this issue. I followed the example. Can anyone confirm or deny that this is still a problem?
-
I solved it by applying …
-
Arguments were not correctly passed from tune to lr_find. This was solved by PR #6784. Please upgrade to the latest version of Lightning :]
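To make the answer above concrete, here is a deliberately simplified, hypothetical model of the tune -> lr_find call chain in plain Python. It is not Lightning's source code: lr_find, tune_buggy, tune_fixed, MisconfigurationError, ModelWithoutHook and ModelWithHook are invented stand-ins. The assumption it encodes is that lr_find falls back to the model's own train_dataloader() hook when no loader is passed in. Under that assumption, dropping the user's loader inside tune reproduces the error from the title, while forwarding it (what #6784 is described as fixing) makes the call succeed.

```python
# Hypothetical, heavily simplified model of the tune() -> lr_find() call chain.
# This is NOT Lightning's implementation; the names only mirror the thread.


class MisconfigurationError(Exception):
    """Stand-in for the exception Lightning raised."""


def lr_find(model, train_dataloader=None):
    # Assumed behaviour: with no loader passed in, fall back to the model's hook.
    if train_dataloader is None:
        if not hasattr(model, "train_dataloader"):
            raise MisconfigurationError("No train_dataloader() method defined.")
        train_dataloader = model.train_dataloader()
    # A real LR finder would sweep learning rates; here we just count batches.
    return sum(1 for _ in train_dataloader)


def tune_buggy(model, train_dataloader=None):
    # The bug: the loader the user passed to tune() is silently dropped.
    return lr_find(model)


def tune_fixed(model, train_dataloader=None):
    # The fix: forward the user's arguments on to lr_find().
    return lr_find(model, train_dataloader=train_dataloader)


class ModelWithoutHook:
    """Like the model in the question: it defines no train_dataloader()."""


class ModelWithHook:
    def train_dataloader(self):
        return [[0.0], [1.0], [2.0]]


loader = [[0.0], [1.0]]

try:
    tune_buggy(ModelWithoutHook(), loader)
except MisconfigurationError as err:
    print(err)  # No train_dataloader() method defined.

print(tune_fixed(ModelWithoutHook(), loader))  # 2: the passed loader is used
print(tune_buggy(ModelWithHook(), loader))  # 3: falls back to the model's hook
```

The last call hints at a possible workaround on releases that still have the bug, besides upgrading: if the LightningModule defines its own train_dataloader() hook, the fallback presumably finds it even when tune drops the loader. That is an inference from this toy model, not something confirmed in the thread.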