where to add preprocessing initialization #7307
-
I would like to have a step that is called before the first training step, but that still needs access to the dataloader, e.g. (mock code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class Scaler(nn.Module):
    """Center target data."""
    def __init__(self, dims):
        super().__init__()
        # running statistics, stored as buffers so they are not optimized
        self.register_buffer("mean", torch.zeros(dims))
        self.register_buffer("n", torch.zeros(1))

    def forward(self, batch):
        input, target = batch
        if self.training:
            # accumulate per-batch means
            self.mean += target.mean(0)
            self.n += 1
        else:
            return input, target - self.mean / self.n


class MySystem(pl.LightningModule):
    def __init__(self, scaler_dims, model_dims):
        super().__init__()
        self.model = nn.Linear(**model_dims)
        self.scaler = Scaler(scaler_dims).train()

    def on_first_epoch(self, dataloader):  # <---- not sure where this should live
        # learn to scale the dataset
        for batch in dataloader:
            self.scaler(batch)

    def training_step(self, batch, batch_idx):
        self.scaler.eval()
        input, target = self.scaler(batch)
        pred = self.model(input)
        loss = F.l1_loss(pred, target)
        return loss


dm = MyDataModule()
system = MySystem(scaler_dims=..., model_dims=...)  # dims elided
trainer = pl.Trainer()
trainer.fit(system, dm)
```

I'm not clear on how to do this with PL's API. Any advice? Thanks!
Replies: 2 comments
-
In this instance, would it be simpler to iterate through the dataset outside of Lightning, prior to starting training?
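A minimal sketch of that approach, reusing the `Scaler`, `MySystem`, and `MyDataModule` names from the question (the dims and the manual datamodule calls are assumptions, not part of the reply):

```python
# Fit the scaler by hand before handing control to Lightning.
dm = MyDataModule()
dm.prepare_data()
dm.setup("fit")

system = MySystem(scaler_dims=..., model_dims=...)  # dims elided, as in the question
system.scaler.train()
for batch in dm.train_dataloader():
    system.scaler(batch)  # accumulate target statistics
system.scaler.eval()

trainer = pl.Trainer()
trainer.fit(system, dm)
```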
-
thanks @ananthsub , I think the point of Lightning is to try to keep everything in the same system. Going through the docs, I think the best is either `self.prepare_data` (which is called only once in distributed training, as opposed to `self.setup`, which runs on every process) or `self.setup` itself.
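For reference, a rough sketch of what the `setup`-hook variant could look like. This is an assumption about where the fitting loop would go, not a confirmed recipe; in particular, reaching the dataloader via `self.trainer.datamodule` inside `setup` is assumed to work because the datamodule's own `setup("fit")` has already run at that point:

```python
class MySystem(pl.LightningModule):
    def __init__(self, scaler_dims, model_dims):
        super().__init__()
        self.model = nn.Linear(**model_dims)
        self.scaler = Scaler(scaler_dims)

    def setup(self, stage=None):
        # fit the scaler once before training starts
        if stage in (None, "fit"):
            self.scaler.train()
            for batch in self.trainer.datamodule.train_dataloader():
                self.scaler(batch)  # accumulate target statistics
            self.scaler.eval()

    def training_step(self, batch, batch_idx):
        input, target = self.scaler(batch)
        pred = self.model(input)
        return F.l1_loss(pred, target)
```

Note that `prepare_data` runs only once (not per process), so state set there is not reliably shared across processes; statistics like these are therefore usually computed in `setup`, which runs on every process.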