Skip to content

Commit df51ef4

Browse files
enhance copy in metatensor (#4506)
* shallow copy in meta structure. Signed-off-by: Wenqi Li <[email protected]> * is_batch: no deepcopy. Signed-off-by: Wenqi Li <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 46dc2ec commit df51ef4

File tree

3 files changed

+66
-60
lines changed

3 files changed

+66
-60
lines changed

monai/data/meta_obj.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
from __future__ import annotations
1313

14+
import itertools
1415
from copy import deepcopy
15-
from typing import Any, Callable, Sequence
16+
from typing import Any, Iterable
1617

1718
from monai.utils.enums import TraceKeys
1819

@@ -74,86 +75,88 @@ class MetaObj:
7475
"""
7576

7677
def __init__(self):
77-
self._meta: dict = self.get_default_meta()
78-
self._applied_operations: list = self.get_default_applied_operations()
78+
self._meta: dict = MetaObj.get_default_meta()
79+
self._applied_operations: list = MetaObj.get_default_applied_operations()
7980
self._is_batch: bool = False
8081

8182
@staticmethod
82-
def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
83+
def flatten_meta_objs(*args: Iterable):
8384
"""
84-
Recursively flatten input and return all instances of `MetaObj` as a single
85-
list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
85+
Recursively flatten input and yield all instances of `MetaObj`.
86+
This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
8687
their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type
8788
`MetaObj`.
8889
8990
Args:
90-
args: Sequence of inputs to be flattened.
91+
args: Iterables of inputs to be flattened.
9192
Returns:
9293
list of nested `MetaObj` from input.
9394
"""
94-
out = []
95-
for a in args:
95+
for a in itertools.chain(*args):
9696
if isinstance(a, (list, tuple)):
97-
out += MetaObj.flatten_meta_objs(a)
97+
yield from MetaObj.flatten_meta_objs(a)
9898
elif isinstance(a, MetaObj):
99-
out.append(a)
100-
return out
99+
yield a
101100

102-
def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
101+
def _copy_attr(self, attributes: list[str], input_objs, defaults: list, deep_copy: bool) -> None:
103102
"""
104-
Copy an attribute from the first in a list of `MetaObj`. In the case of
103+
Copy attributes from the first in a list of `MetaObj`. In the case of
105104
`torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so
106105
check them all. Copy the first to `self`.
107106
108107
We also perform a deep copy of the data if desired.
109108
110109
Args:
111-
attribute: string corresponding to attribute to be copied (e.g., `meta`).
112-
input_objs: List of `MetaObj`. We'll copy the attribute from the first one
110+
attributes: a sequence of strings corresponding to attributes to be copied (e.g., `['meta']`).
111+
input_objs: an iterable of `MetaObj` instances. We'll copy the attribute from the first one
113112
that contains that particular attribute.
114-
default_fn: If none of `input_objs` have the attribute that we're
115-
interested in, then use this default function (e.g., `lambda: {}`.)
116-
deep_copy: Should the attribute be deep copied? See `_copy_meta`.
113+
defaults: If none of `input_objs` have the attribute that we're
114+
interested in, then use this default value/function (e.g., `lambda: {}`.)
115+
the defaults must be the same length as `attributes`.
116+
deep_copy: whether to deep copy the corresponding attribute.
117117
118118
Returns:
119119
Returns `None`, but `self` should be updated to have the copied attribute.
120120
"""
121-
attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)]
122-
if len(attributes) > 0:
123-
val = attributes[0]
124-
if deep_copy:
125-
val = deepcopy(val)
126-
setattr(self, attribute, val)
127-
else:
128-
setattr(self, attribute, default_fn())
129-
130-
def _copy_meta(self, input_objs: list[MetaObj]) -> None:
121+
found = [False] * len(attributes)
122+
for i, (idx, a) in itertools.product(input_objs, enumerate(attributes)):
123+
if not found[idx] and hasattr(i, a):
124+
setattr(self, a, deepcopy(getattr(i, a)) if deep_copy else getattr(i, a))
125+
found[idx] = True
126+
if all(found):
127+
return
128+
for a, f, d in zip(attributes, found, defaults):
129+
if not f:
130+
setattr(self, a, d() if callable(defaults) else d)
131+
return
132+
133+
def _copy_meta(self, input_objs, deep_copy=False) -> None:
131134
"""
132-
Copy metadata from a list of `MetaObj`. For a given attribute, we copy the
135+
Copy metadata from an iterable of `MetaObj` instances. For a given attribute, we copy the
133136
adjunct data from the first element in the list containing that attribute.
134137
135-
If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g.,
136-
`a+=1`), then don't.
137-
138138
Args:
139139
input_objs: list of `MetaObj` to copy data from.
140140
141141
"""
142-
id_in = id(input_objs[0]) if len(input_objs) > 0 else None
143-
deep_copy = id(self) != id_in
144-
self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy)
145-
self._copy_attr("applied_operations", input_objs, self.get_default_applied_operations, deep_copy)
146-
self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False
142+
self._copy_attr(
143+
["meta", "applied_operations"],
144+
input_objs,
145+
[MetaObj.get_default_meta(), MetaObj.get_default_applied_operations()],
146+
deep_copy,
147+
)
147148

148-
def get_default_meta(self) -> dict:
149+
@staticmethod
150+
def get_default_meta() -> dict:
149151
"""Get the default meta.
150152
151153
Returns:
152154
default metadata.
153155
"""
154156
return {}
155157

156-
def get_default_applied_operations(self) -> list:
158+
@staticmethod
159+
def get_default_applied_operations() -> list:
157160
"""Get the default applied operations.
158161
159162
Returns:
@@ -183,28 +186,28 @@ def __repr__(self) -> str:
183186
@property
184187
def meta(self) -> dict:
185188
"""Get the meta."""
186-
return self._meta if hasattr(self, "_meta") else self.get_default_meta()
189+
return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta()
187190

188191
@meta.setter
189192
def meta(self, d) -> None:
190193
"""Set the meta."""
191194
if d == TraceKeys.NONE:
192-
self._meta = self.get_default_meta()
195+
self._meta = MetaObj.get_default_meta()
193196
self._meta = d
194197

195198
@property
196199
def applied_operations(self) -> list:
197200
"""Get the applied operations."""
198201
if hasattr(self, "_applied_operations"):
199202
return self._applied_operations
200-
return self.get_default_applied_operations()
203+
return MetaObj.get_default_applied_operations()
201204

202205
@applied_operations.setter
203206
def applied_operations(self, t) -> None:
204207
"""Set the applied operations."""
205208
if t == TraceKeys.NONE:
206209
# received no operations when decollating a batch
207-
self._applied_operations = self.get_default_applied_operations()
210+
self._applied_operations = MetaObj.get_default_applied_operations()
208211
return
209212
self._applied_operations = t
210213

monai/data/meta_tensor.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import warnings
1515
from copy import deepcopy
16-
from typing import Any, Callable, Sequence
16+
from typing import Any, Sequence
1717

1818
import torch
1919

@@ -126,19 +126,20 @@ def __init__(
126126
elif isinstance(x, MetaTensor):
127127
self.applied_operations = x.applied_operations
128128
else:
129-
self.applied_operations = self.get_default_applied_operations()
129+
self.applied_operations = MetaObj.get_default_applied_operations()
130130

131131
# if we are creating a new MetaTensor, then deep copy attributes
132132
if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
133133
self.meta = deepcopy(self.meta)
134134
self.applied_operations = deepcopy(self.applied_operations)
135135
self.affine = self.affine.to(self.device)
136136

137-
def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
138-
super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
139-
val = getattr(self, attribute)
140-
if isinstance(val, torch.Tensor):
141-
setattr(self, attribute, val.to(self.device))
137+
def _copy_attr(self, attributes: list[str], input_objs, defaults: list, deep_copy: bool) -> None:
138+
super()._copy_attr(attributes, input_objs, defaults, deep_copy)
139+
for a in attributes:
140+
val = getattr(self, a)
141+
if isinstance(val, torch.Tensor):
142+
setattr(self, a, val.to(self.device))
142143

143144
@staticmethod
144145
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
@@ -173,6 +174,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
173174
"""
174175
out = []
175176
metas = None
177+
is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
176178
for idx, ret in enumerate(rets):
177179
# if not `MetaTensor`, nothing to do.
178180
if not isinstance(ret, MetaTensor):
@@ -182,20 +184,18 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
182184
ret = ret.as_tensor()
183185
# else, handle the `MetaTensor` metadata.
184186
else:
185-
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
186-
# this is not implemented but the network arch may run into this case:
187+
meta_args = MetaObj.flatten_meta_objs(args, kwargs.values()) # type: ignore
188+
ret._copy_meta(meta_args, deep_copy=not is_batch)
189+
ret.is_batch = is_batch
190+
# the following is not implemented but the network arch may run into this case:
187191
# if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
188192
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
189-
ret._copy_meta(meta_args)
190193

191194
# If we have a batch of data, then we need to be careful if a slice of
192195
# the data is returned. Depending on how the data are indexed, we return
193196
# some or all of the metadata, and the return object may or may not be a
194197
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
195-
if ret.is_batch:
196-
# only decollate metadata once
197-
if metas is None:
198-
metas = decollate_batch(ret.meta)
198+
if is_batch:
199199
# if indexing e.g., `batch[0]`
200200
if func == torch.Tensor.__getitem__:
201201
batch_idx = args[1]
@@ -205,6 +205,9 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
205205
# first element will be `slice(None, None, None)` and `Ellipsis`,
206206
# respectively. Don't need to do anything with the metadata.
207207
if batch_idx not in (slice(None, None, None), Ellipsis):
208+
# only decollate metadata once
209+
if metas is None:
210+
metas = decollate_batch(ret.meta)
208211
meta = metas[batch_idx]
209212
# if using e.g., `batch[0:2]`, then `is_batch` should still be
210213
# `True`. Also re-collate the remaining elements.
@@ -226,6 +229,8 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
226229
else:
227230
dim = 0
228231
if dim == 0:
232+
if metas is None:
233+
metas = decollate_batch(ret.meta)
229234
ret.meta = metas[idx]
230235
ret.is_batch = False
231236

tests/test_integration_fast_train.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
Compose,
3535
CropForegroundd,
3636
EnsureChannelFirstd,
37-
EnsureTyped,
3837
FgBgToIndicesd,
3938
LoadImaged,
4039
RandAffined,
@@ -94,7 +93,6 @@ def test_train_timing(self):
9493
# and cache them to accelerate training
9594
FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
9695
# move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
97-
EnsureTyped(keys=["image", "label"], drop_meta=True),
9896
ToDeviced(keys=["image", "label"], device=device),
9997
# randomly crop out patch samples from big
10098
# image based on pos / neg ratio

0 commit comments

Comments (0)