Skip to content

Commit 64c5995

Browse files
committed
Added empty_cache function to clear memory for CPU, CUDA, and MPS devices; updated imports accordingly.
1 parent bdc73c3 commit 64c5995

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mipcandy/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop
66
from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \
77
inspect, ROIDataset, RandomROIDataset
8-
from mipcandy.data.io import fast_save, fast_load, resample_to_isotropic, load_image, save_image
8+
from mipcandy.data.io import fast_save, fast_load, resample_to_isotropic, load_image, save_image, empty_cache
99
from mipcandy.data.sliding_window import do_sliding_window, revert_sliding_window, slide_dataset, \
1010
UnsupervisedSWDataset, SupervisedSWDataset
1111
from mipcandy.data.transform import JointTransform, MONAITransform

mipcandy/data/io.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from gc import collect
12
from math import floor
23
from os import PathLike
34

@@ -56,3 +57,13 @@ def save_image(image: torch.Tensor, path: str | PathLike[str]) -> None:
5657
image = auto_convert(ensure_num_dimensions(image, 3)).to(torch.uint8).permute(1, 2, 0)
5758
return SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy(), isVector=True), path)
5859
raise NotImplementedError(f"Unsupported file type: {path}")
60+
61+
62+
def empty_cache(device: Device) -> None:
63+
match torch.device(device).type:
64+
case "cpu":
65+
collect()
66+
case "cuda":
67+
torch.cuda.empty_cache()
68+
case "mps":
69+
torch.mps.empty_cache()

0 commit comments

Comments
 (0)