How to keep track of all outputs during training? #16127
-
Dear Lightning discussion channel, thank you for building such an amazing tool and development environment. I am currently trying to build a simple callback that keeps hold of all outputs and targets of a model during training. I have a rough draft, but the images are not fed in a fixed order (`shuffle=True`), so I wonder if there is a way to recover each image's index in the dataset. My current approach is to define the following method:

```python
def on_train_batch_end(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
    outputs: STEP_OUTPUT, batch: Any, batch_idx: int,
) -> None:
    ...
    self.outputs[trainer.current_epoch, batch_size * batch_idx : (batch_idx + 1) * batch_size] = outputs["outputs"]
    self.targets[trainer.current_epoch, batch_size * batch_idx : (batch_idx + 1) * batch_size] = outputs["targets"]
```

However, due to shuffling, the slice `batch_size * batch_idx : (batch_idx + 1) * batch_size` does not correspond to a fixed set of samples across epochs, so I am still looking for a simple way to do this. I thought about hashing the value of each image, but data augmentation makes that unreliable. Hopefully you can help me!
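For reference, a minimal self-contained version of such a callback might look like the sketch below. The preallocation shapes, the `num_epochs`/`num_samples` constructor arguments, and the dict keys returned by `training_step` are assumptions for illustration, not from the original post:

```python
import torch
from pytorch_lightning.callbacks import Callback


class OutputTracker(Callback):
    # Keeps a copy of every training output and target, indexed by epoch.
    # Assumes training_step returns a dict with "outputs" and "targets",
    # each holding one scalar per sample.
    def __init__(self, num_epochs: int, num_samples: int, batch_size: int) -> None:
        self.batch_size = batch_size
        self.outputs = torch.zeros(num_epochs, num_samples)
        self.targets = torch.zeros(num_epochs, num_samples)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        start = batch_idx * self.batch_size
        end = start + len(outputs["outputs"])  # the last batch may be smaller
        self.outputs[trainer.current_epoch, start:end] = outputs["outputs"].detach().cpu()
        self.targets[trainer.current_epoch, start:end] = outputs["targets"].detach().cpu()
```

As the question notes, this only stores results in dataset order when the dataloader does not shuffle; with `shuffle=True` the `start:end` slice maps to a different set of samples every epoch.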
Replies: 1 comment
-
In the end, the solution was simply to modify the dataset class dynamically:

```python
from typing import Any, Tuple


def index_return_wrapper(dataset_cls: type) -> type:
    # Subclass the given dataset so __getitem__ also returns the sample index.
    class IndexReturnWrapper(dataset_cls):
        def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
            return super().__getitem__(index) + (index,)

    return IndexReturnWrapper
```
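For illustration, here is how the wrapper might be used, and how the returned index lets the callback store results in dataset order regardless of shuffling. The concrete dataset class and the dict keys are assumptions, not part of the original answer:

```python
from torchvision.datasets import CIFAR10

# Wrap any dataset class so each sample also yields its dataset index.
IndexedCIFAR10 = index_return_wrapper(CIFAR10)
dataset = IndexedCIFAR10(root="data", train=True, download=True)

image, target, index = dataset[0]  # __getitem__ now returns (image, target, index)
```

With the index available in every batch, `training_step` can pass it through (e.g. `return {"loss": loss, "outputs": out.detach(), "targets": y, "indices": idx}`), and the callback can then scatter by index instead of relying on batch order:

```python
epoch = trainer.current_epoch
self.outputs[epoch, outputs["indices"]] = outputs["outputs"].detach().cpu()
self.targets[epoch, outputs["indices"]] = outputs["targets"].detach().cpu()
```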