Commit d35759a (parent 6f657a4)

3296 adds a flag for invertible on/off, and decouples transform stack API (#3295)

* adds traceable API
* drop peek
* deprecate inversekeys
* inversekeys -> tracekeys
* update trace_key
* update based on comments

Signed-off-by: Wenqi Li <[email protected]>

16 files changed: +285 additions, -176 deletions

monai/data/test_time_augmentation.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@
 from monai.transforms.inverse_batch_transform import BatchInverseTransform
 from monai.transforms.transform import Randomizable
 from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode
-from monai.utils.enums import CommonKeys, InverseKeys
+from monai.utils.enums import CommonKeys, TraceKeys
 from monai.utils.module import optional_import
 
 if TYPE_CHECKING:

@@ -168,7 +168,7 @@ def __call__(
         ds = Dataset(data_in, self.transform)
         dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)
 
-        transform_key = self.orig_key + InverseKeys.KEY_SUFFIX
+        transform_key = InvertibleTransform.trace_key(self.orig_key)
 
         # create inverter
         inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate)

@@ -188,7 +188,7 @@ def __call__(
         transform_info = batch_data.get(transform_key, None)
         if transform_info is None:
             # no invertible transforms, adding dummy info for identity invertible
-            transform_info = [[InverseKeys.NONE] for _ in range(self.batch_size)]
+            transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)]
         if self.nearest_interp:
             transform_info = convert_inverse_interp_mode(
                 trans_info=deepcopy(transform_info), mode="nearest", align_corners=None
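A minimal sketch of the decoupled lookup (assuming MONAI at this commit): `InvertibleTransform.trace_key(key)` replaces manual string concatenation with `InverseKeys.KEY_SUFFIX`; the `"_transforms"` suffix value below is an assumption carried over from the old enum.

    from monai.transforms import InvertibleTransform

    # old style: transform_key = key + InverseKeys.KEY_SUFFIX
    # new style: a helper on the transform class, decoupled from the enum
    transform_key = InvertibleTransform.trace_key("image")
    assert transform_key == "image" + "_transforms"  # assumed suffix value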

monai/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@
     ThresholdIntensityD,
     ThresholdIntensityDict,
 )
-from .inverse import InvertibleTransform
+from .inverse import InvertibleTransform, TraceableTransform
 from .inverse_batch_transform import BatchInverseTransform, Decollated
 from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
 from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
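`TraceableTransform` is newly exported here: tracing (recording the applied operations) is split from inversion. A small sketch of the assumed relationship between the two classes, since both come from the same `.inverse` module:

    from monai.transforms import InvertibleTransform, TraceableTransform

    # assumption: InvertibleTransform builds on TraceableTransform, so every
    # invertible transform records a trace, while a traceable transform need
    # not implement inverse()
    assert issubclass(InvertibleTransform, TraceableTransform)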

monai/transforms/compose.py

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@
     apply_transform,
 )
 from monai.utils import MAX_SEED, ensure_tuple, get_seed
-from monai.utils.enums import InverseKeys
+from monai.utils.enums import TraceKeys
 
 __all__ = ["Compose", "OneOf"]
 

@@ -237,7 +237,7 @@ def __call__(self, data):
         # if the data is a mapping (dictionary), append the OneOf transform to the end
         if isinstance(data, Mapping):
             for key in data.keys():
-                if key + InverseKeys.KEY_SUFFIX in data:
+                if self.trace_key(key) in data:
                     self.push_transform(data, key, extra_info={"index": index})
         return data
 

@@ -250,9 +250,9 @@ def inverse(self, data):
         # loop until we get an index and then break (since they'll all be the same)
         index = None
         for key in data.keys():
-            if key + InverseKeys.KEY_SUFFIX in data:
+            if self.trace_key(key) in data:
                 # get the index of the applied OneOf transform
-                index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"]
+                index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
                 # and then remove the OneOf transform
                 self.pop_transform(data, key)
         if index is None:
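End to end, the traced index looks like this; a minimal sketch assuming MONAI at this commit, with `Flipd` chosen arbitrarily as an invertible dictionary transform:

    import numpy as np
    from monai.transforms import Flipd, OneOf
    from monai.utils.enums import TraceKeys

    one_of = OneOf([Flipd(keys="image", spatial_axis=0), Flipd(keys="image", spatial_axis=1)])
    data = one_of({"image": np.zeros((1, 4, 4))})

    # OneOf appends its own entry to the per-key trace stack, recording which
    # branch ran so that inverse() can dispatch to the right sub-transform
    entry = data[one_of.trace_key("image")][-1]
    print(entry[TraceKeys.EXTRA_INFO])  # e.g. {'index': 0}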

monai/transforms/croppad/batch.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@
 from monai.data.utils import list_data_collate
 from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
 from monai.transforms.inverse import InvertibleTransform
-from monai.utils.enums import InverseKeys, Method, NumpyPadMode
+from monai.utils.enums import Method, NumpyPadMode, TraceKeys
 
 __all__ = ["PadListDataCollate"]
 

@@ -115,12 +115,12 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]:
 
         d = deepcopy(data)
         for key in d:
-            transform_key = str(key) + InverseKeys.KEY_SUFFIX
+            transform_key = InvertibleTransform.trace_key(key)
             if transform_key in d:
                 transform = d[transform_key][-1]
                 if not isinstance(transform, Dict):
                     continue
-                if transform.get(InverseKeys.CLASS_NAME) == PadListDataCollate.__name__:
+                if transform.get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__:
                     d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key])  # fallback to image size
                 # remove transform
                 d[transform_key].pop()
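Each trace-stack entry is a plain dict tagged with the class that pushed it, which is how `PadListDataCollate.inverse` recognizes its own entries. A hypothetical helper (`last_entry_is` is not MONAI API) illustrating the check:

    from monai.transforms import InvertibleTransform
    from monai.utils.enums import TraceKeys

    def last_entry_is(data: dict, key, class_name: str) -> bool:
        # was the most recent trace entry for `key` pushed by `class_name`?
        stack = data.get(InvertibleTransform.trace_key(key), [])
        if not stack or not isinstance(stack[-1], dict):
            return False
        return stack[-1].get(TraceKeys.CLASS_NAME) == class_name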

monai/transforms/croppad/dictionary.py

Lines changed: 23 additions & 23 deletions
@@ -52,7 +52,7 @@
 )
 from monai.utils import ImageMetaKey as Key
 from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first
-from monai.utils.enums import InverseKeys
+from monai.utils.enums import TraceKeys
 
 __all__ = [
     "PadModeSequence",

@@ -163,7 +163,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = transform[InverseKeys.ORIG_SIZE]
+            orig_size = transform[TraceKeys.ORIG_SIZE]
             if self.padder.method == Method.SYMMETRIC:
                 current_size = d[key].shape[1:]
                 roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)]

@@ -239,15 +239,15 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
             roi_start = np.array(self.padder.spatial_border)
             # Need to convert single value to [min1,min2,...]
             if roi_start.size == 1:
                 roi_start = np.full((len(orig_size)), roi_start)
             # need to convert [min1,max1,min2,...] to [min1,min2,...]
             elif roi_start.size == 2 * orig_size.size:
                 roi_start = roi_start[::2]
-            roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start
+            roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start
 
             inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end)
             # Apply inverse transform

@@ -315,7 +315,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
             current_size = np.array(d[key].shape[1:])
             roi_start = np.floor((current_size - orig_size) / 2)
             roi_end = orig_size + roi_start

@@ -384,7 +384,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
             current_size = np.array(d[key].shape[1:])
             # get required pad to start and end
             pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)])

@@ -440,7 +440,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
             current_size = np.array(d[key].shape[1:])
             pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
             # in each direction, if original size is even and current size is odd, += 1

@@ -497,7 +497,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
             current_size = np.array(d[key].shape[1:])
             pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
             # in each direction, if original size is even and current size is odd, += 1

@@ -594,12 +594,12 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = transform[InverseKeys.ORIG_SIZE]
+            orig_size = transform[TraceKeys.ORIG_SIZE]
             random_center = self.random_center
             pad_to_start = np.empty((len(orig_size)), dtype=np.int32)
             pad_to_end = np.empty((len(orig_size)), dtype=np.int32)
             if random_center:
-                for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]):
+                for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]):
                     pad_to_start[i] = _slice[0]
                     pad_to_end[i] = orig_size[i] - _slice[1]
             else:
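`RandSpatialCropd` stores its random crop slices under `TraceKeys.EXTRA_INFO`, and the inverse converts them back into padding amounts. A hypothetical helper (`pad_amounts_from_slices` is not MONAI API) condensing that arithmetic:

    from monai.utils.enums import TraceKeys

    def pad_amounts_from_slices(transform: dict) -> list:
        # per axis: pad back `start` voxels before and `orig - stop` after
        orig_size = transform[TraceKeys.ORIG_SIZE]
        slices = transform[TraceKeys.EXTRA_INFO]["slices"]
        return [(start, size - stop) for (start, stop), size in zip(slices, orig_size)]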
@@ -776,8 +776,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
         cropped = self.cropper(d)
         # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd
         for key in self.key_iterator(cropped):
-            cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__  # type: ignore
-            cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self)  # type: ignore
+            cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__  # type: ignore
+            cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self)  # type: ignore
         # add `patch_index` to the meta data
         for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
             meta_key = meta_key or f"{key}_{meta_key_postfix}"

@@ -792,8 +792,8 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
         # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd
         # Need to revert that since we're calling RandSpatialCropd's inverse
         for key in self.key_iterator(d):
-            d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__
-            d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper)
+            d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__
+            d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper)
         context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext
         with context_manager(self.cropper):
             return self.cropper.inverse(d)
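Because `RandSpatialCropSamplesd` delegates to an inner `RandSpatialCropd`, it relabels the entry the cropper pushed: claiming it in `__call__`, handing it back in `inverse`. A minimal sketch of that relabelling (`relabel_last_entry` is a hypothetical name, not MONAI API):

    from monai.utils.enums import TraceKeys

    def relabel_last_entry(data: dict, trace_key: str, owner) -> None:
        # point the most recent trace entry for `trace_key` at `owner`, so
        # the matching inverse() is looked up against that object
        entry = data[trace_key][-1]
        entry[TraceKeys.CLASS_NAME] = owner.__class__.__name__
        entry[TraceKeys.ID] = id(owner)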
@@ -877,9 +877,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
             cur_size = np.asarray(d[key].shape[1:])
-            extra_info = transform[InverseKeys.EXTRA_INFO]
+            extra_info = transform[TraceKeys.EXTRA_INFO]
             box_start = np.asarray(extra_info["box_start"])
             box_end = np.asarray(extra_info["box_end"])
             # first crop the padding part

@@ -999,9 +999,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
             current_size = np.asarray(d[key].shape[1:])
-            center = transform[InverseKeys.EXTRA_INFO]["center"]
+            center = transform[TraceKeys.EXTRA_INFO]["center"]
             cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
             # get required pad to start and end
             pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])

@@ -1179,9 +1179,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
             current_size = np.asarray(d[key].shape[1:])
-            center = transform[InverseKeys.EXTRA_INFO]["center"]
+            center = transform[TraceKeys.EXTRA_INFO]["center"]
             cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
             # get required pad to start and end
             pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])

@@ -1364,9 +1364,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
             current_size = np.asarray(d[key].shape[1:])
-            center = transform[InverseKeys.EXTRA_INFO]["center"]
+            center = transform[TraceKeys.EXTRA_INFO]["center"]
             cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
             # get required pad to start and end
             pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])

@@ -1432,7 +1432,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         for key in self.key_iterator(d):
             transform = self.get_most_recent_transform(d, key)
             # Create inverse transform
-            orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+            orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
             current_size = np.array(d[key].shape[1:])
             # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding.
             # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions.
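All of these inverses share one recipe: read `TraceKeys.ORIG_SIZE` (plus any `TraceKeys.EXTRA_INFO`) from the most recent trace entry, then build the opposite pad/crop. A condensed sketch of the symmetric case (`symmetric_pad_amounts` is a hypothetical helper, not MONAI API):

    import numpy as np
    from monai.utils.enums import TraceKeys

    def symmetric_pad_amounts(transform: dict, current_size) -> tuple:
        # padding that undoes a center crop back to the recorded original size
        orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
        current_size = np.asarray(current_size)
        pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
        pad_to_end = orig_size - current_size - pad_to_start  # absorbs odd/even rounding
        return pad_to_start, pad_to_end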
