Help understanding data module error #8298
-
Hi there, I am trying to implement a data module; however, I keep getting an error that I cannot understand. I normally set up my data module as:
class dataModule(pl.LightningDataModule):
def __init__(self, batch_size, csv_file, data_dir):
    """Record the configuration and create empty dataset placeholders.

    Args:
        batch_size: Batch size passed to every ``DataLoader`` built by
            this module.
        csv_file: Forwarded to ``fmriDataset`` as ``csv_file`` in ``setup``.
        data_dir: Forwarded to ``fmriDataset`` as ``root_dir`` in ``setup``.
    """
    super().__init__()
    self.batch_size = batch_size
    self.csv_file, self.data_dir = csv_file, data_dir
    # Not assigned anywhere else in the visible code.
    self.preprocess = self.transform = None
    # Filled in by setup(); None until then.
    self.train_set = self.val_set = self.test_set = None
def get_augmentation_transform(self):
    """Build the random augmentation pipeline used for the training set.

    Returns:
        A ``tio.Compose`` that chains, in this order: random affine, flip,
        gamma, noise, motion and bias-field transforms.
    """
    steps = [
        tio.RandomAffine(),
        tio.RandomFlip(p=0.25),
        tio.RandomGamma(p=0.25),
        tio.RandomNoise(p=0.25),
        tio.RandomMotion(p=0.1),
        tio.RandomBiasField(p=0.25),
    ]
    return tio.Compose(steps)
def setup(self, stage=None):
# Splits the fmriDataset into train/val/test subsets and wraps them in
# tio.SubjectsDataset objects. `stage` is compared against 'fit', 'test'
# and None to decide which subsets are (re)drawn.
subjList = fmriDataset(csv_file = self.csv_file,
root_dir = self.data_dir)
# 70 % and 20 % of the items (rounded down) go to train and val; whatever
# is left over becomes the test split, so the three sizes always sum to
# len(subjList).
train_size, val_size = int(0.7 * len(subjList)), int(0.2 * len(subjList))
test_size = len(subjList) - train_size - val_size
if stage == 'fit' or stage is None:
self.train_dataset, self.val_dataset, _ = torch.utils.data.random_split(subjList, [train_size, val_size, test_size])
if stage == 'test' or stage is None:
# NOTE(review): this is a second, independent random_split call with no
# explicit generator, so the test subset drawn here is not guaranteed to
# be disjoint from the train/val subsets drawn in the 'fit' branch.
_, _, self.test_dataset = torch.utils.data.random_split(subjList, [train_size, val_size, test_size])
augment = self.get_augmentation_transform()
# NOTE(review): the three assignments below run for every stage, but
# self.train_dataset / self.val_dataset are only assigned in the
# 'fit'/None branch and self.test_dataset only in the 'test'/None branch
# (__init__ does not create them). Calling setup('fit') on a fresh
# instance therefore fails on the last line with
# AttributeError: ... has no attribute 'test_dataset'.
self.train_set = tio.SubjectsDataset(self.train_dataset, transform=augment)
self.val_set = tio.SubjectsDataset(self.val_dataset, transform=None)
self.test_set = tio.SubjectsDataset(self.test_dataset, transform=None)
def train_dataloader(self):
    """Return a shuffling ``DataLoader`` over ``self.train_set``."""
    loader = DataLoader(
        self.train_set,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=27,
    )
    return loader
def val_dataloader(self):
    """Return a ``DataLoader`` over ``self.val_set`` (no shuffling requested)."""
    loader = DataLoader(
        self.val_set,
        batch_size=self.batch_size,
        num_workers=27,
    )
    return loader
def test_dataloader(self):
return DataLoader(self.test_set, self.batch_size, num_workers=27)
However, when I run this I get the error: AttributeError: 'dataModule' object has no attribute 'test_dataset'. If I change these lines:
if stage == 'fit' or stage is None:
self.train_dataset, self.val_dataset, _ = torch.utils.data.random_split(subjList, [train_size, val_size, test_size])
if stage == 'test' or stage is None:
_, _, self.test_dataset = torch.utils.data.random_split(subjList, [train_size, val_size, test_size]) to a single line: self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(subjList, [train_size, val_size, test_size]) everything works. Is there something basic that I have missed?
And I initialise the data module/model/trainer with: data = dataModule(data_dir = '/home/data/', csv_dir = '/home/scanList.csv', batch_size = 24)
model = cnnRnnClassifier()

# Early stopping on the monitored `val_loss` metric (lower is better):
# an improvement must be at least 1e-4 to count, and training stops after
# 10 checks without one.
# NOTE(review): assumes this is pytorch_lightning.callbacks.EarlyStopping --
# the class is spelled with a capital "S" and its flag is lowercase
# `verbose`; the original `Earlystopping(..., Verbose=True)` would fail with
# a NameError (and `Verbose` would be rejected as an unexpected keyword).
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=1e-4,
    patience=10,
    verbose=True,
    mode='min',
)

trainer = Trainer(
    gpus=1,
    fast_dev_run=False,
    max_epochs=100,
    weights_summary='full',
    callbacks=[early_stop_callback],
    auto_lr_find=True,
    precision=16,
)

# tune() runs the learning-rate finder requested via auto_lr_find before the
# actual fit.
trainer.tune(model=model, datamodule=data)
trainer.fit(model=model, datamodule=data)
trainer.test(model = model, datamodule = data)
Thanks in advance for your help!
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
The problem is the combination of this line: if stage == 'test' or stage is None:
_, _, self.test_dataset = torch.utils.data.random_split(subjList, [train_size, val_size, test_size]) and this line: self.test_set = tio.SubjectsDataset(self.test_dataset, transform=None) as you can see, |
Beta Was this translation helpful? Give feedback.
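The problem is the combination of this line:
if stage == 'test' or stage is None:
_, _, self.test_dataset = torch.utils.data.random_split(subjList, [train_size, val_size, test_size])
and this line:
self.test_set = tio.SubjectsDataset(self.test_dataset, transform=None)
As you can see, self.test_dataset is only defined if the condition above applies, but the second line uses it unconditionally. With this hint you should be able to figure it out now. Let me know :)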