How to implement channels last memory format callback #7349
-
Hi there! The PyTorch docs recommend using the channels last memory format when training vision models in mixed precision. To enable it, you need to make two changes:

1. Convert the model to the channels_last memory format.
2. Convert each input batch to the channels_last memory format.

My problem is with step 2. I can't find any PyTorch Lightning hook that allows me to make this modification to the batch. The only options left are to add it as a data transform (which would have to be used in conjunction with the callback) or to do all the channels-last logic inside the LightningModule. I would prefer to avoid the latter, as it could clutter the LightningModule with unnecessary code. Do you know a way to do step 2 in a callback?
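For reference, the two changes look roughly like this (a minimal, hypothetical sketch with a stand-in `Conv2d` instead of a real vision model):

```python
import torch
import torch.nn as nn

# Step 1: convert the model's parameters and buffers to channels_last.
model = nn.Conv2d(3, 16, kernel_size=3)  # stand-in for a real vision backbone
model = model.to(memory_format=torch.channels_last)

# Step 2: convert each NCHW input batch to channels_last as well.
images = torch.randn(8, 3, 224, 224)
images = images.to(memory_format=torch.channels_last)
output = model(images)
```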
Replies: 1 comment
-
You can use any of the following if done inside the LightningModule:

```python
batch = self.on_before_batch_transfer(batch, dataloader_idx)
batch = self.transfer_batch_to_device(batch, device)
batch = self.on_after_batch_transfer(batch, dataloader_idx)
```

If you really need to do it in the Callback, I guess you could use `on_train_batch_start` since the modification is in-place. But I wouldn't recommend it.
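Inside the LightningModule, overriding `on_after_batch_transfer` could look roughly like this (a minimal sketch, not the only way to do it; it assumes the collated batch is an `(images, targets)` pair, and the `ChannelsLastModel` name is made up):

```python
import torch
import pytorch_lightning as pl


class ChannelsLastModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Step 1: store the backbone weights in channels_last memory format.
        self.backbone = torch.nn.Conv2d(3, 16, kernel_size=3)  # placeholder backbone
        self.backbone = self.backbone.to(memory_format=torch.channels_last)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Step 2: convert the images once the batch has been moved to the device.
        images, targets = batch
        return images.to(memory_format=torch.channels_last), targets
```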
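And the Callback route with `on_train_batch_start`, if you really want it — a rough sketch, assuming the collated batch is a mutable `[images, targets]` list so the in-place replacement is visible to the later `training_step` (the `ChannelsLastCallback` name is made up; older Lightning versions also pass a `dataloader_idx` argument to this hook):

```python
import torch
from pytorch_lightning import Callback


class ChannelsLastCallback(Callback):
    """Switch image tensors to channels_last right before each training step."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        images = batch[0] if isinstance(batch, list) and batch else None
        if torch.is_tensor(images) and images.dim() == 4:
            # Replace the list element in place so training_step sees the converted tensor.
            batch[0] = images.to(memory_format=torch.channels_last)
```

You would then register it with `Trainer(callbacks=[ChannelsLastCallback()])`.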