Efficient Prediction Writing for 2D/3D Medical Image Segmentation Tasks #17089
ndahiya3 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
I am working on 2D/3D medical image segmentation tasks, which naturally involve large amounts of training data. I am trying to write a custom prediction writer but am running into out-of-memory issues.
For a 2D image segmentation task, I have written a pseudo-3D dataloader that collects all the 2D slices belonging to a single 3D dataset. I then have a list of such dataloaders, each comprising the slices of one validation or test 3D dataset. I want to collect the outputs of one dataloader, call the prediction writer at the end of that dataloader, save the prediction (and/or compute per-volume metrics such as the 3D Dice score or Hausdorff distance), and then clear the outputs. I cannot save and clear outputs at the batch level, because the predictions of one dataloader must be saved together as a single 3D dataset. I also cannot cache the outputs of all the dataloaders until on_validation_epoch_end, because that consumes far too much memory on either the GPU or the CPU.
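Because the 3D Dice score is defined over the whole volume, it can only be computed once every slice from a dataloader has been collected. Below is a minimal pure-Python sketch over stacked binary masks. The `dice_3d` name, the nested-list mask representation and the `eps` smoothing term are our own assumptions; in practice one would use NumPy/PyTorch arrays or a library metric such as MONAI's `DiceMetric` (and `HausdorffDistanceMetric` for the Hausdorff distance).

```python
from typing import Sequence

# One 2D binary mask stored as rows of 0/1 values (hypothetical representation).
Slice2D = Sequence[Sequence[int]]


def dice_3d(pred: Sequence[Slice2D], target: Sequence[Slice2D], eps: float = 1e-7) -> float:
    """Volume-level Dice, 2|P ∩ T| / (|P| + |T|), over all slices of one volume."""
    if len(pred) != len(target):
        raise ValueError("pred and target must contain the same number of slices")
    intersection = pred_sum = target_sum = 0
    for pred_slice, target_slice in zip(pred, target):
        for pred_row, target_row in zip(pred_slice, target_slice):
            for p, t in zip(pred_row, target_row):
                intersection += p * t
                pred_sum += p
                target_sum += t
    # eps keeps the ratio defined (and equal to 1.0) when both masks are empty.
    return (2 * intersection + eps) / (pred_sum + target_sum + eps)


# Two-slice toy volume: 2 overlapping voxels, 3 predicted, 3 in the ground truth.
pred = [[[1, 1], [0, 0]], [[1, 0], [0, 0]]]
target = [[[1, 0], [0, 0]], [[1, 0], [0, 1]]]
print(round(dice_3d(pred, target), 4))  # 0.6667
```

The counts are accumulated across all slices before the ratio is formed, which is what makes this a volume-level score rather than an average of per-slice scores.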
For example, if I have the following in my segmentation model:
I may have 100k image slices, so there is no way I can cache the 'preds' outputs accumulated over a whole validation (or test/inference) epoch.
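A quick back-of-the-envelope estimate shows why epoch-level caching fails while caching one volume at a time does not. The 512×512 slice size, the single float32 output channel and the ~200 slices per volume are illustrative assumptions, not figures from the post; multi-class probability maps multiply the result by the number of classes.

```python
def prediction_cache_bytes(num_slices: int, height: int, width: int,
                           channels: int = 1, bytes_per_elem: int = 4) -> int:
    """Memory needed to keep every slice's prediction resident at the same time."""
    return num_slices * channels * height * width * bytes_per_elem


# Assumed numbers: 100k slices of 512x512, one float32 channel per slice.
epoch_cache = prediction_cache_bytes(100_000, 512, 512)
# Assumed volume size: ~200 slices per 3D dataset.
volume_cache = prediction_cache_bytes(200, 512, 512)

print(f"whole epoch: {epoch_cache / 2**30:.1f} GiB")   # whole epoch: 97.7 GiB
print(f"one volume:  {volume_cache / 2**20:.0f} MiB")  # one volume:  200 MiB
```

Each 512×512 float32 slice is exactly 1 MiB, so the cache grows by 1 MiB per slice and only a per-volume buffer stays bounded.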
I have managed to work around this with a lot of ugly code in my segmentation model: it caches the outputs internally instead of returning them from validation_step or test_step, and keeps track of the dataloader_idx argument. Whenever that index changes, I call a function that writes the predictions for the dataloader that has just finished and clears the cached outputs. But the whole thing becomes ugly and tedious, and I lose the main benefit of PyTorch Lightning, which had cleaned up my code by removing the training boilerplate.
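One way to move this bookkeeping out of the LightningModule is a callback that performs the same dataloader-index tracking. This is a sketch, not an official Lightning feature: `PerVolumeWriter`, `write_volume` and the `"preds"` output key are hypothetical names. It assumes (1) `validation_step` returns a dict holding the batch of slice predictions under that key, (2) each validation dataloader contains exactly one volume with its slices in order (`shuffle=False`), and (3) Lightning runs multiple validation dataloaders one after another. The hooks it overrides, `Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)` and `Callback.on_validation_epoch_end(trainer, pl_module)`, are part of Lightning's callback API.

```python
from typing import Any, Callable, List, Optional

try:
    from lightning.pytorch.callbacks import Callback
except ImportError:  # lets this sketch run without Lightning installed
    Callback = object  # type: ignore[misc,assignment]


class PerVolumeWriter(Callback):
    """Hypothetical callback: buffers the slice predictions of one validation
    dataloader (= one 3D volume), hands the complete volume to `write_volume`
    when the next dataloader starts, then frees the buffer."""

    def __init__(self, write_volume: Callable[[int, List[Any]], None], key: str = "preds"):
        super().__init__()
        self.write_volume = write_volume  # user-supplied: save volume / compute 3D metrics
        self.key = key
        self._idx: Optional[int] = None
        self._slices: List[Any] = []

    def _flush(self) -> None:
        if self._idx is not None and self._slices:
            self.write_volume(self._idx, self._slices)
        self._idx, self._slices = None, []  # drop references so memory can be reclaimed

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                                dataloader_idx=0):
        if getattr(trainer, "sanity_checking", False):
            return  # skip the partial sanity-check run
        if self._idx is not None and dataloader_idx != self._idx:
            self._flush()  # previous dataloader finished, so its volume is complete
        self._idx = dataloader_idx
        preds = outputs[self.key]
        if hasattr(preds, "detach"):  # torch.Tensor: drop autograd graph, move to CPU
            preds = preds.detach().cpu()
        self._slices.extend(preds)  # iterate the batch dimension: one item per slice

    def on_validation_epoch_end(self, trainer, pl_module):
        self._flush()  # the last dataloader has no successor to trigger a flush
```

It would be registered with `Trainer(callbacks=[PerVolumeWriter(save_fn)])`, where `save_fn` is your own function that stacks the slices and writes the volume or computes its 3D metrics. The test and predict loops have matching `on_test_batch_end`/`on_test_epoch_end` and `on_predict_batch_end`/`on_predict_epoch_end` hooks; for prediction, subclassing `BasePredictionWriter` with `write_interval="batch"` and overriding `write_on_batch_end` is a comparable starting point. Either way, only one volume is held in memory at a time.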
Any thoughts or suggestions on this issue are appreciated. I can explain the problem with more code samples if needed.