Iterating over tasks for Continual Learning #11724
-
Hi everyone, I am new to PyTorch Lightning and I am currently trying to implement a continual learning model with it. I have multiple dataloaders for different tasks and I want to train on all of them. After training on task 1 with dataloader 1, I want to change which parameters of the model are trained for task 2. To do this, my datamodule has a `current_task` attribute that decides which dataset the samples for the current task are drawn from. My datamodule looks something like this:

```python
from typing import Optional

import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets
from pytorch_lightning import LightningDataModule


class RandSplitCIFAR100DataModule(LightningDataModule):
    def __init__(self):
        ...

    def setup(self, stage: Optional[str] = None):
        # load datasets only if they're not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            self.data_train = datasets.CIFAR100(self.hparams.data_dir, train=True, transform=self.train_transforms)
            self.data_val = datasets.CIFAR100(self.hparams.data_dir, train=False, transform=self.val_transforms)

            np.random.seed(self.hparams.seed)
            perm = np.random.permutation(self.num_classes)
            print(perm)

            # split the permuted classes into groups of 5, one group per task
            splits = [
                (self.partition_datasetv4(self.data_train, perm[5 * i:5 * (i + 1)]),
                 self.partition_datasetv4(self.data_val, perm[5 * i:5 * (i + 1)]),)
                for i in range(self.hparams.num_tasks)
            ]
            kwargs = {"num_workers": self.hparams.workers, "pin_memory": self.hparams.pin_memory}
            self.loaders = [
                (DataLoader(x[0], batch_size=self.hparams.batch_size, shuffle=True, **kwargs),
                 DataLoader(x[1], batch_size=self.hparams.test_batch_size, shuffle=False, **kwargs),)
                for x in splits
            ]

    def update_task(self, i):
        self.current_task = i

    def train_dataloader(self):
        return self.loaders[self.current_task][0]

    def val_dataloader(self):
        return self.loaders[self.current_task][1]
```

Now I want to have a training loop that does something like this:

```python
for task in range(num_tasks):
    datamodule.update_task(task)
    for n, p in model.named_parameters():
        ...  # change which parameters are trained for this task
    for epoch in range(max_epochs):
        for batch in dataloader:
            ...
```

I am currently not able to figure out how to do this with Lightning. I feel confident that Lightning should be able to handle such cases, but I am just not sure how. Any help is greatly appreciated!
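For context, by "change parameters to update" I mean something like the sketch below, i.e. freezing everything except the parameters belonging to the current task (the `task_{i}` name prefix is just a placeholder for whatever naming scheme the model actually uses):

```python
def select_task_parameters(model, task_id):
    # Freeze all task-specific parameters except those of the current task;
    # parameters without a task prefix (the shared backbone) stay trainable.
    prefix = f"task_{task_id}."
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(prefix) or not name.startswith("task_")
```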
-
well, there are multiple ways:

(1) Keep a single `Trainer` run, let `reload_dataloaders_every_n_epochs` switch the dataloaders, and update the model parameters at the task boundaries:

```python
class LitModel(LightningModule):
    def on_train_epoch_start(self):
        if self.current_epoch == 0 or (self.current_epoch + 1) % self.trainer.reload_dataloaders_every_n_epochs == 0:
            # update model parameters
            ...


max_epochs_n_tasks = max_epochs * n_tasks
trainer = Trainer(max_epochs=max_epochs_n_tasks, reload_dataloaders_every_n_epochs=max_epochs)
model = LitModel()
# inject the update task counter logic inside the datamodule
dm = RandSplitCIFAR100DataModule(...)
trainer.fit(model, datamodule=dm)
```
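In (1), "inject the update task counter logic inside the datamodule" could, for example, mean deriving the task index from the trainer's epoch whenever the dataloaders are reloaded. A minimal sketch, assuming every task runs for the same number of epochs and that `epochs_per_task` is a (hypothetical) hparam:

```python
from pytorch_lightning import LightningDataModule


class RandSplitCIFAR100DataModule(LightningDataModule):
    ...

    def train_dataloader(self):
        # Lightning attaches the trainer to the datamodule during fit,
        # so the current epoch can be mapped to a task index here.
        self.current_task = self.trainer.current_epoch // self.hparams.epochs_per_task
        return self.loaders[self.current_task][0]

    def val_dataloader(self):
        return self.loaders[self.current_task][1]
```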
(2) Create a fresh `Trainer` for every task and call `fit` once per task:

```python
def init_trainer(...):
    trainer = Trainer(max_epochs=max_epochs, ...)
    return trainer


datamodule = ...
model = ...

for task in range(num_tasks):
    # update params
    datamodule.update_task(task)
    trainer = init_trainer(...)
    trainer.fit(model, datamodule=datamodule)
```

Although I'd suggest (1). Even if your max_epochs differs for each task, it can easily be extended to support that too.
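For the last point, a rough sketch of how (1) could be extended when max_epochs differs per task: precompute the cumulative epoch boundaries and look the task up from the current epoch (the epoch counts below are made up):

```python
import bisect

epochs_per_task = [10, 20, 15]   # hypothetical per-task epoch budgets
boundaries = []                  # cumulative sums, e.g. [10, 30, 45]
total = 0
for n in epochs_per_task:
    total += n
    boundaries.append(total)


def task_for_epoch(epoch: int) -> int:
    # the first boundary strictly greater than `epoch` gives the task index
    return bisect.bisect_right(boundaries, epoch)
```

The trainer would then run with `Trainer(max_epochs=sum(epochs_per_task), reload_dataloaders_every_n_epochs=1)`, and both the datamodule and `on_train_epoch_start` would use `task_for_epoch(...)` on the current epoch instead of a fixed divisor.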