Status: Closed
Labels: loops (Related to the Loop API), waiting on author (Waiting on user action, correction, or update)
Description
Bug description
I start training with `trainer.fit(model, datamodule=dm)`, where `dm` is an instance of a class that inherits from `pl.LightningDataModule`. In that class, I override `train_dataloader`:
```python
from torch.utils.data import DataLoader

def train_dataloader(self):
    train_dataset = MixedBatchMultiviewDataset(self.args, self.tokenizer,
                                               known_exs=self.known_train,
                                               unknown_exs=self.unknown_train,
                                               feature=self.args.feature)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=self.args.train_batch_size,
                                  shuffle=True, num_workers=self.args.num_workers,
                                  pin_memory=True, collate_fn=self.collate_batch_feat)
    return train_dataloader
```
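As written, `train_dataloader()` constructs a new `MixedBatchMultiviewDataset` from `self.known_train` and `self.unknown_train` on every call, so labels written into an earlier dataset instance only survive a reload if they also end up in state that this method reads. A minimal pure-Python sketch of keeping the pseudo labels on the DataModule itself; `ToyDataset`, `ToyDataModule` and `pseudo_labels` are hypothetical stand-ins, and plain lists of dicts replace `DataLoader`:

```python
from typing import Dict, List


class ToyDataset:
    """Hypothetical stand-in for MixedBatchMultiviewDataset."""

    def __init__(self, uids: List[int], pseudo_labels: Dict[int, int]) -> None:
        self.uids = list(uids)
        # Copy the mapping: this dataset is a snapshot of the labels at build time.
        self.pseudo_labels = dict(pseudo_labels)

    def __len__(self) -> int:
        return len(self.uids)

    def __getitem__(self, i: int) -> Dict[str, int]:
        uid = self.uids[i]
        # -1 marks "no pseudo label yet" (an arbitrary choice for this sketch).
        return {"uid": uid, "pseudo_label": self.pseudo_labels.get(uid, -1)}


class ToyDataModule:
    """Hypothetical DataModule that owns the pseudo labels as its own state."""

    def __init__(self, uids: List[int]) -> None:
        self.uids = uids
        self.pseudo_labels: Dict[int, int] = {}

    def update_pseudo_labels(self, uid2pl: Dict[int, int]) -> None:
        # Update the DataModule, not a dataset instance that a reload may discard.
        self.pseudo_labels.update(uid2pl)

    def train_dataloader(self, batch_size: int = 2) -> List[List[Dict[str, int]]]:
        # A fresh dataset on every call, so each reload reads the latest labels.
        ds = ToyDataset(self.uids, self.pseudo_labels)
        items = [ds[i] for i in range(len(ds))]
        return [items[k:k + batch_size] for k in range(0, len(items), batch_size)]


dm = ToyDataModule([10, 11, 12])
before = dm.train_dataloader()
dm.update_pseudo_labels({10: 1, 12: 0})
after = dm.train_dataloader()
print([ex["pseudo_label"] for b in after for ex in b])  # [1, -1, 0]
```

In Lightning, the public way to have `train_dataloader()` called again during `fit` is the `Trainer` argument `reload_dataloaders_every_n_epochs` (for example `reload_dataloaders_every_n_epochs=1`), instead of resetting the private `fit_loop._combined_loader`. The reload appears to run before `on_train_epoch_start`, so labels updated in that hook would only reach the next epoch; updating them in `on_train_epoch_end` may fit better, but the hook order is worth checking for v2.3.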
In the model's `on_train_epoch_start` hook, I update the dataset and then rebuild the fit loop's data:
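```python
def on_train_epoch_start(self):
    # uid2pl is computed elsewhere (not shown in this report)
    train_dl = self.trainer.train_dataloader
    train_dl.dataset.update_pseudo_labels(uid2pl)
    # Force the fit loop to rebuild its data (relies on the private `_combined_loader` attribute)
    loop = self.trainer.fit_loop
    loop._combined_loader = None
    loop.setup_data()
```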
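One possible explanation, offered as an assumption rather than something confirmed from Lightning 2.3's `FitLoop` source: if the iterator that feeds `training_step` for the current epoch already exists when `on_train_epoch_start` runs, then `setup_data()` replaces what `trainer.train_dataloader` reports without touching that iterator. The toy loop below is plain Python, not Lightning code, and only models that ordering; `toy_epoch` and the `"loader"` key are hypothetical:

```python
from typing import Dict, Iterator, List, Optional


def toy_epoch(state: Dict[str, List[str]], iterator_before_hook: bool) -> List[str]:
    """Toy model of one epoch's ordering; NOT Lightning's actual FitLoop."""

    def on_epoch_start() -> None:
        # Stands in for the hook that calls setup_data(): swap in a rebuilt loader.
        state["loader"] = ["new-0", "new-1"]

    it: Optional[Iterator[str]] = None
    if iterator_before_hook:
        # The epoch's iterator is bound to the loader object that exists right now.
        it = iter(state["loader"])
    on_epoch_start()
    if it is None:
        it = iter(state["loader"])
    # The items "training_step" would receive this epoch.
    return list(it)


state = {"loader": ["old-0", "old-1"]}
seen = toy_epoch(state, iterator_before_hook=True)
print(seen, state["loader"])  # ['old-0', 'old-1'] ['new-0', 'new-1']
```

With `iterator_before_hook=True` the batches stay old while the stored loader is new, which matches the symptom reported here; with `False` the rebuilt loader is used.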
However, in `training_step` the batch still contains the old data, even though `self.trainer.train_dataloader.dataset` has been updated:
```python
from typing import Dict, List

import torch

def training_step(self, batch: List[Dict[str, torch.Tensor]], batch_idx: int):
    self.mv_model._on_train_batch_start()
    logger.info(self.trainer.train_dataloader.dataset.unknown_feats)  # new
    logger.info(batch)  # old
```
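To check how many examples in a batch carry outdated labels, instead of logging whole batches, one option is to pass each example's uid and pseudo label through `collate_batch_feat` and compare them with the dataset's current mapping. A minimal sketch, assuming hypothetical `uid` and `pseudo_label` keys holding plain ints (tensor values would need converting first, e.g. with `.tolist()`):

```python
from typing import Dict, List, Mapping


def count_stale(batch: List[Dict[str, int]], current: Mapping[int, int]) -> int:
    """Count examples whose pseudo label differs from the current mapping.

    Assumes each example carries hypothetical 'uid' and 'pseudo_label' keys,
    i.e. the collate function passes them through alongside the tensors.
    """
    return sum(1 for ex in batch if current.get(ex["uid"]) != ex["pseudo_label"])


batch = [{"uid": 1, "pseudo_label": 0}, {"uid": 2, "pseudo_label": 1}]
current = {1: 0, 2: 2}
print(count_stale(batch, current))  # 1
```

If `num_workers > 0`, each DataLoader worker process works on its own copy of the dataset, so in-place updates made in the main process after the workers have started are not visible to those workers for the rest of that epoch.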
What version are you seeing the problem on?
v2.3
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response