Skip to content

Commit 3583ea0

Browse files
authored
update Flip/Rotate/Resize spatial transforms support MetaTensor (#4429)
1 parent 60a9693 commit 3583ea0

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

64 files changed

+1467
-1058
lines changed

monai/apps/detection/transforms/dictionary.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def __init__(
395395
self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs)
396396
self.keep_size = keep_size
397397

398-
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
398+
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
399399
d = dict(data)
400400

401401
# zoom box
@@ -408,7 +408,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
408408
box_key,
409409
extra_info={"zoom": self.zoomer.zoom, "src_spatial_size": src_spatial_size, "type": "box_key"},
410410
)
411-
d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)(
411+
d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)( # type: ignore
412412
d[box_key], src_spatial_size=src_spatial_size
413413
)
414414

@@ -431,7 +431,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
431431

432432
return d
433433

434-
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
434+
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
435435
d = deepcopy(dict(data))
436436

437437
for key in self.key_iterator(d):
@@ -461,7 +461,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
461461
zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"])
462462
src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"]
463463
box_inverse_transform = ZoomBox(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size)
464-
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)
464+
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) # type: ignore
465465

466466
# Remove the applied transform
467467
self.pop_transform(d, key)
@@ -545,7 +545,7 @@ def set_random_state(
545545
self.rand_zoom.set_random_state(seed, state)
546546
return self
547547

548-
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
548+
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
549549
d = dict(data)
550550
first_key: Union[Hashable, List] = self.first_key(d)
551551
if first_key == []:
@@ -568,7 +568,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
568568
box_key,
569569
extra_info={"zoom": self.rand_zoom._zoom, "src_spatial_size": src_spatial_size, "type": "box_key"},
570570
)
571-
d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)(
571+
d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)( # type: ignore
572572
d[box_key], src_spatial_size=src_spatial_size
573573
)
574574

@@ -595,7 +595,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
595595

596596
return d
597597

598-
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
598+
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
599599
d = deepcopy(dict(data))
600600

601601
for key in self.key_iterator(d):
@@ -626,7 +626,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
626626
zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"])
627627
src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"]
628628
box_inverse_transform = ZoomBox(zoom=(1.0 / zoom).tolist(), keep_size=self.rand_zoom.keep_size)
629-
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)
629+
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) # type: ignore
630630

631631
# Remove the applied transform
632632
self.pop_transform(d, key)
@@ -667,7 +667,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
667667
d = dict(data)
668668

669669
for key in self.image_keys:
670-
d[key] = self.flipper(d[key])
670+
d[key] = self.flipper(d[key]) # type: ignore
671671
self.push_transform(d, key, extra_info={"type": "image_key"})
672672

673673
for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
@@ -685,7 +685,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
685685

686686
# flip image, copied from monai.transforms.spatial.dictionary.Flipd
687687
if key_type == "image_key":
688-
d[key] = self.flipper(d[key])
688+
d[key] = self.flipper(d[key]) # type: ignore
689689

690690
# flip boxes
691691
if key_type == "box_key":
@@ -743,7 +743,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
743743

744744
for key in self.image_keys:
745745
if self._do_transform:
746-
d[key] = self.flipper(d[key], randomize=False)
746+
d[key] = self.flipper(d[key], randomize=False) # type: ignore
747747
self.push_transform(d, key, extra_info={"type": "image_key"})
748748

749749
for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
@@ -763,7 +763,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
763763
if transform[TraceKeys.DO_TRANSFORM]:
764764
# flip image, copied from monai.transforms.spatial.dictionary.RandFlipd
765765
if key_type == "image_key":
766-
d[key] = self.flipper(d[key], randomize=False)
766+
d[key] = self.flipper(d[key], randomize=False) # type: ignore
767767

768768
# flip boxes
769769
if key_type == "box_key":
@@ -1271,7 +1271,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
12711271
self.push_transform(d, key, extra_info={"spatial_size": spatial_size, "type": "box_key"})
12721272

12731273
for key in self.image_keys:
1274-
d[key] = self.img_rotator(d[key])
1274+
d[key] = self.img_rotator(d[key]) # type: ignore
12751275
self.push_transform(d, key, extra_info={"type": "image_key"})
12761276
return d
12771277

@@ -1285,7 +1285,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
12851285

12861286
if key_type == "image_key":
12871287
inverse_transform = Rotate90(num_times_to_rotate, self.img_rotator.spatial_axes)
1288-
d[key] = inverse_transform(d[key])
1288+
d[key] = inverse_transform(d[key]) # type: ignore
12891289
if key_type == "box_key":
12901290
spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
12911291
inverse_transform = RotateBox90(num_times_to_rotate, self.box_rotator.spatial_axes)
@@ -1329,7 +1329,7 @@ def __init__(
13291329
super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys)
13301330
self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))
13311331

1332-
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
1332+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: # type: ignore
13331333
self.randomize()
13341334
d = dict(data)
13351335

@@ -1357,11 +1357,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
13571357

13581358
for key in self.image_keys:
13591359
if self._do_transform:
1360-
d[key] = img_rotator(d[key])
1360+
d[key] = img_rotator(d[key]) # type: ignore
13611361
self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"})
13621362
return d
13631363

1364-
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
1364+
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: # type: ignore
13651365
d = deepcopy(dict(data))
13661366
if self._rand_k % 4 == 0:
13671367
return d
@@ -1376,7 +1376,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
13761376
# flip image, copied from monai.transforms.spatial.dictionary.RandFlipd
13771377
if key_type == "image_key":
13781378
inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
1379-
d[key] = inverse_transform(d[key])
1379+
d[key] = inverse_transform(d[key]) # type: ignore
13801380
if key_type == "box_key":
13811381
spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
13821382
inverse_transform = RotateBox90(num_times_to_rotate, self.spatial_axes)

monai/data/image_writer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import warnings
1213
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union
1314

1415
import numpy as np
@@ -269,6 +270,9 @@ def resample_if_needed(
269270
resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype)
270271
output_array = resampler(data_array[None], dst_affine=target_affine, spatial_size=output_spatial_shape)
271272
# convert back at the end
273+
if isinstance(output_array, MetaTensor):
274+
warnings.warn("ignoring the tracking transform info.")
275+
output_array.applied_operations = []
272276
data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore
273277
affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore
274278
return data_array[0], affine
@@ -764,11 +768,11 @@ def resample_and_clip(
764768
_min, _max = np.min(data), np.max(data)
765769
if len(data.shape) == 3:
766770
data = np.moveaxis(data, -1, 0) # to channel first
767-
data = xform(data) # type: ignore
771+
data = convert_data_type(xform(data), np.ndarray, drop_meta=True)[0] # type: ignore
768772
data = np.moveaxis(data, 0, -1)
769773
else: # (H, W)
770774
data = np.expand_dims(data, 0) # make a channel
771-
data = xform(data)[0] # type: ignore
775+
data = convert_data_type(xform(data), np.ndarray, drop_meta=True)[0][0] # type: ignore
772776
if mode != InterpolateMode.NEAREST:
773777
data = np.clip(data, _min, _max)
774778
return data

monai/data/meta_obj.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def __repr__(self) -> str:
183183
@property
184184
def meta(self) -> dict:
185185
"""Get the meta."""
186-
return self._meta
186+
return self._meta if hasattr(self, "_meta") else self.get_default_meta()
187187

188188
@meta.setter
189189
def meta(self, d) -> None:
@@ -195,7 +195,9 @@ def meta(self, d) -> None:
195195
@property
196196
def applied_operations(self) -> list:
197197
"""Get the applied operations."""
198-
return self._applied_operations
198+
if hasattr(self, "_applied_operations"):
199+
return self._applied_operations
200+
return self.get_default_applied_operations()
199201

200202
@applied_operations.setter
201203
def applied_operations(self, t) -> None:
@@ -215,7 +217,7 @@ def pop_applied_operation(self) -> Any:
215217
@property
216218
def is_batch(self) -> bool:
217219
"""Return whether object is part of batch or not."""
218-
return self._is_batch
220+
return self._is_batch if hasattr(self, "_is_batch") else False
219221

220222
@is_batch.setter
221223
def is_batch(self, val: bool) -> None:

monai/data/meta_tensor.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
183183
# else, handle the `MetaTensor` metadata.
184184
else:
185185
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
186+
# this is not implemented but the network arch may run into this case:
187+
# if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
188+
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
186189
ret._copy_meta(meta_args)
187190

188191
# If we have a batch of data, then we need to be careful if a slice of
@@ -195,17 +198,17 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
195198
metas = decollate_batch(ret.meta)
196199
# if indexing e.g., `batch[0]`
197200
if func == torch.Tensor.__getitem__:
198-
idx = args[1]
199-
if isinstance(idx, Sequence):
200-
idx = idx[0]
201+
batch_idx = args[1]
202+
if isinstance(batch_idx, Sequence):
203+
batch_idx = batch_idx[0]
201204
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
202205
# first element will be `slice(None, None, None)` and `Ellipsis`,
203206
# respectively. Don't need to do anything with the metadata.
204-
if idx not in (slice(None, None, None), Ellipsis):
205-
meta = metas[idx]
207+
if batch_idx not in (slice(None, None, None), Ellipsis):
208+
meta = metas[batch_idx]
206209
# if using e.g., `batch[0:2]`, then `is_batch` should still be
207210
# `True`. Also re-collate the remaining elements.
208-
if isinstance(meta, list) and len(meta) > 1:
211+
if isinstance(meta, list):
209212
ret.meta = list_data_collate(meta)
210213
# if using e.g., `batch[0]` or `batch[0, 1]`, then return single
211214
# element from batch, and set `is_batch` to `False`.
@@ -243,6 +246,19 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
243246
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
244247
# else (e.g., `__repr__` returns a string).
245248
# Convert to list (if necessary), process, and at end remove list if one was added.
249+
if (
250+
hasattr(torch, "return_types")
251+
and hasattr(func, "__name__")
252+
and hasattr(torch.return_types, func.__name__)
253+
and isinstance(getattr(torch.return_types, func.__name__), type)
254+
and isinstance(ret, getattr(torch.return_types, func.__name__))
255+
):
256+
# for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like
257+
out_items = MetaTensor.update_meta(ret, func, args, kwargs)
258+
for idx in range(ret.n_fields):
259+
ret[idx].meta = out_items[idx].meta
260+
ret[idx].applied_operations = out_items[idx].applied_operations
261+
return ret
246262
if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence):
247263
ret = [ret]
248264
unpack = True

monai/data/png_writer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414
import numpy as np
1515

1616
from monai.transforms.spatial.array import Resize
17-
from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import
17+
from monai.utils import (
18+
InterpolateMode,
19+
convert_data_type,
20+
deprecated,
21+
ensure_tuple_rep,
22+
look_up_option,
23+
optional_import,
24+
)
1825

1926
Image, _ = optional_import("PIL", name="Image")
2027

@@ -74,9 +81,9 @@ def write_png(
7481
if scale is not None:
7582
data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1]
7683
if scale == np.iinfo(np.uint8).max:
77-
data = (scale * data).astype(np.uint8, copy=False)
84+
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8, drop_meta=True)[0]
7885
elif scale == np.iinfo(np.uint16).max:
79-
data = (scale * data).astype(np.uint16, copy=False)
86+
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16, drop_meta=True)[0]
8087
else:
8188
raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]")
8289

monai/data/utils.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,24 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
392392
return
393393

394394

395+
def collate_meta_tensor(batch):
396+
"""collate a sequence of meta tensor sequences/dictionaries into
397+
a single batched metatensor or a dictionary of batched metatensor"""
398+
if not isinstance(batch, Sequence):
399+
raise NotImplementedError()
400+
elem_0 = first(batch)
401+
if isinstance(elem_0, MetaObj):
402+
collated = default_collate(batch)
403+
collated.meta = default_collate([i.meta or TraceKeys.NONE for i in batch])
404+
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
405+
collated.is_batch = True
406+
return collated
407+
if isinstance(elem_0, Mapping):
408+
return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}
409+
# no more recursive search for MetaTensor
410+
return default_collate(batch)
411+
412+
395413
def list_data_collate(batch: Sequence):
396414
"""
397415
Enhancement for PyTorch DataLoader default collate.
@@ -411,19 +429,9 @@ def list_data_collate(batch: Sequence):
411429
for k in elem:
412430
key = k
413431
data_for_batch = [d[key] for d in data]
414-
ret[key] = default_collate(data_for_batch)
415-
if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch):
416-
meta_list = [i.meta or TraceKeys.NONE for i in data_for_batch]
417-
ret[key].meta = default_collate(meta_list)
418-
ops_list = [i.applied_operations or TraceKeys.NONE for i in data_for_batch]
419-
ret[key].applied_operations = default_collate(ops_list)
420-
ret[key].is_batch = True
432+
ret[key] = collate_meta_tensor(data_for_batch)
421433
else:
422-
ret = default_collate(data)
423-
if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data):
424-
ret.meta = default_collate([i.meta or TraceKeys.NONE for i in data])
425-
ret.applied_operations = default_collate([i.applied_operations or TraceKeys.NONE for i in data])
426-
ret.is_batch = True
434+
ret = collate_meta_tensor(data)
427435
return ret
428436
except RuntimeError as re:
429437
re_str = str(re)
@@ -550,7 +558,7 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
550558
if isinstance(t, MetaObj):
551559
t.meta = m
552560
t.is_batch = False
553-
for t, m in zip(out_list, decollate_batch(batch.applied_operations)):
561+
for t, m in zip(out_list, batch.applied_operations):
554562
if isinstance(t, MetaObj):
555563
t.applied_operations = m
556564
t.is_batch = False
@@ -848,7 +856,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa
848856
an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)
849857
850858
"""
851-
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
859+
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True, drop_meta=True)[0]
852860
affine_np = affine_np.copy()
853861
if affine_np.ndim != 2:
854862
raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.")

0 commit comments

Comments (0)