Instantiate data augmentations through CLI #12424
-
Hi,

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.utilities.cli import LightningCLI, MODEL_REGISTRY
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform
class DummyModel(pl.LightningModule):
    """Minimal LightningModule used to reproduce the LightningCLI issue.

    Each step takes the first two entries of ``batch[0]``, flattens them,
    projects both through one shared linear layer and returns the MSE
    between the two projections.
    """

    def __init__(self):
        super().__init__()
        # Single linear layer over a flattened 32*32*3 input with 10 outputs.
        self.fc = nn.Linear(32*32*3, 10)
        self.loss_fn = nn.MSELoss()

    def shared_step(self, batch, batch_idx):
        """Return the MSE between the projections of the first two inputs.

        ``batch_idx`` is accepted but not used.
        """
        # NOTE(review): presumably batch[0] holds the augmented views produced
        # by the SimCLR transforms and batch[1] the labels -- confirm.
        x, y = batch[0][:2]
        # Flatten everything after the leading (batch) dimension before the
        # linear layer.
        z1 = self.fc(x.reshape(x.size(0), -1))
        z2 = self.fc(y.reshape(y.size(0), -1))
        return self.loss_fn(z1, z2)

    def training_step(self, batch, batch_idx):
        """Return the loss computed by ``shared_step``."""
        return self.shared_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Return the loss computed by ``shared_step``."""
        return self.shared_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with default settings."""
        return torch.optim.Adam(self.parameters())
# For comparison: the author reports that the manual (non-CLI) setup in the
# commented lines below runs perfectly.
# trainer = pl.Trainer()
# model = DummyModel()
# dm = CIFAR10DataModule()
# dm.train_transforms = SimCLRTrainDataTransform(32)
# dm.val_transforms = SimCLREvalDataTransform(32)
# trainer.fit(model, dm)
# The same model and datamodule classes handed to LightningCLI instead.
# NOTE(review): run=False presumably makes the CLI only build the objects
# without launching a trainer subcommand -- confirm against the LightningCLI docs.
cli = LightningCLI(DummyModel, CIFAR10DataModule, run=False)
# Reported problem: the value printed next is not instantiated!
print(cli.config_init.data.train_transforms)

And here is my config.yaml:

data:
  train_transforms:
    class_path: pl_bolts.models.self_supervised.simclr.SimCLRTrainDataTransform
    init_args:
      input_height: 32
  val_transforms:
    class_path: pl_bolts.models.self_supervised.simclr.SimCLREvalDataTransform
    init_args:
      input_height: 32

To run the code, I run the following command:
and here is the error I get:
What I understand from the error is that |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 7 replies
-
cc @carmocca |
Beta Was this translation helpful? Give feedback.
cc @carmocca