|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +import itertools |
14 | 15 | from copy import deepcopy |
15 | | -from typing import Any, Callable, Sequence |
| 16 | +from typing import Any, Iterable |
16 | 17 |
|
17 | 18 | from monai.utils.enums import TraceKeys |
18 | 19 |
|
@@ -74,86 +75,88 @@ class MetaObj: |
74 | 75 | """ |
75 | 76 |
|
76 | 77 | def __init__(self): |
77 | | - self._meta: dict = self.get_default_meta() |
78 | | - self._applied_operations: list = self.get_default_applied_operations() |
| 78 | + self._meta: dict = MetaObj.get_default_meta() |
| 79 | + self._applied_operations: list = MetaObj.get_default_applied_operations() |
79 | 80 | self._is_batch: bool = False |
80 | 81 |
|
81 | 82 | @staticmethod |
82 | | - def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: |
| 83 | + def flatten_meta_objs(*args: Iterable): |
83 | 84 | """ |
84 | | - Recursively flatten input and return all instances of `MetaObj` as a single |
85 | | - list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and |
| 85 | + Recursively flatten input and yield all instances of `MetaObj`. |
| 86 | + This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and |
86 | 87 | their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type |
87 | 88 | `MetaObj`. |
88 | 89 |
|
89 | 90 | Args: |
90 | | - args: Sequence of inputs to be flattened. |
| 91 | + args: Iterables of inputs to be flattened. |
91 | 92 | Returns: |
92 | 93 | list of nested `MetaObj` from input. |
93 | 94 | """ |
94 | | - out = [] |
95 | | - for a in args: |
| 95 | + for a in itertools.chain(*args): |
96 | 96 | if isinstance(a, (list, tuple)): |
97 | | - out += MetaObj.flatten_meta_objs(a) |
| 97 | + yield from MetaObj.flatten_meta_objs(a) |
98 | 98 | elif isinstance(a, MetaObj): |
99 | | - out.append(a) |
100 | | - return out |
| 99 | + yield a |
101 | 100 |
|
102 | | - def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: |
| 101 | + def _copy_attr(self, attributes: list[str], input_objs, defaults: list, deep_copy: bool) -> None: |
103 | 102 | """ |
104 | | - Copy an attribute from the first in a list of `MetaObj`. In the case of |
| 103 | + Copy attributes from the first in a list of `MetaObj`. In the case of |
105 | 104 | `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so |
106 | 105 | check them all. Copy the first to `self`. |
107 | 106 |
|
108 | 107 | We also perform a deep copy of the data if desired. |
109 | 108 |
|
110 | 109 | Args: |
111 | | - attribute: string corresponding to attribute to be copied (e.g., `meta`). |
112 | | - input_objs: List of `MetaObj`. We'll copy the attribute from the first one |
| 110 | + attributes: a sequence of strings corresponding to attributes to be copied (e.g., `['meta']`). |
| 111 | + input_objs: an iterable of `MetaObj` instances. We'll copy the attribute from the first one |
113 | 112 | that contains that particular attribute. |
114 | | - default_fn: If none of `input_objs` have the attribute that we're |
115 | | - interested in, then use this default function (e.g., `lambda: {}`.) |
116 | | - deep_copy: Should the attribute be deep copied? See `_copy_meta`. |
| 113 | + defaults: If none of `input_objs` have the attribute that we're |
| 114 | + interested in, then use this default value/function (e.g., `lambda: {}`.) |
| 115 | + the defaults must be the same length as `attributes`. |
| 116 | + deep_copy: whether to deep copy the corresponding attribute. |
117 | 117 |
|
118 | 118 | Returns: |
119 | 119 | Returns `None`, but `self` should be updated to have the copied attribute. |
120 | 120 | """ |
121 | | - attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)] |
122 | | - if len(attributes) > 0: |
123 | | - val = attributes[0] |
124 | | - if deep_copy: |
125 | | - val = deepcopy(val) |
126 | | - setattr(self, attribute, val) |
127 | | - else: |
128 | | - setattr(self, attribute, default_fn()) |
129 | | - |
130 | | - def _copy_meta(self, input_objs: list[MetaObj]) -> None: |
| 121 | + found = [False] * len(attributes) |
| 122 | + for i, (idx, a) in itertools.product(input_objs, enumerate(attributes)): |
| 123 | + if not found[idx] and hasattr(i, a): |
| 124 | + setattr(self, a, deepcopy(getattr(i, a)) if deep_copy else getattr(i, a)) |
| 125 | + found[idx] = True |
| 126 | + if all(found): |
| 127 | + return |
| 128 | + for a, f, d in zip(attributes, found, defaults): |
| 129 | + if not f: |
| 130 | + setattr(self, a, d() if callable(defaults) else d) |
| 131 | + return |
| 132 | + |
| 133 | + def _copy_meta(self, input_objs, deep_copy=False) -> None: |
131 | 134 | """ |
132 | | - Copy metadata from a list of `MetaObj`. For a given attribute, we copy the |
| 135 | + Copy metadata from an iterable of `MetaObj` instances. For a given attribute, we copy the |
133 | 136 | adjunct data from the first element in the list containing that attribute. |
134 | 137 |
|
135 | | - If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g., |
136 | | - `a+=1`), then don't. |
137 | | -
|
138 | 138 | Args: |
139 | 139 | input_objs: list of `MetaObj` to copy data from. |
140 | 140 |
|
141 | 141 | """ |
142 | | - id_in = id(input_objs[0]) if len(input_objs) > 0 else None |
143 | | - deep_copy = id(self) != id_in |
144 | | - self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) |
145 | | - self._copy_attr("applied_operations", input_objs, self.get_default_applied_operations, deep_copy) |
146 | | - self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False |
| 142 | + self._copy_attr( |
| 143 | + ["meta", "applied_operations"], |
| 144 | + input_objs, |
| 145 | + [MetaObj.get_default_meta(), MetaObj.get_default_applied_operations()], |
| 146 | + deep_copy, |
| 147 | + ) |
147 | 148 |
|
148 | | - def get_default_meta(self) -> dict: |
| 149 | + @staticmethod |
| 150 | + def get_default_meta() -> dict: |
149 | 151 | """Get the default meta. |
150 | 152 |
|
151 | 153 | Returns: |
152 | 154 | default metadata. |
153 | 155 | """ |
154 | 156 | return {} |
155 | 157 |
|
156 | | - def get_default_applied_operations(self) -> list: |
| 158 | + @staticmethod |
| 159 | + def get_default_applied_operations() -> list: |
157 | 160 | """Get the default applied operations. |
158 | 161 |
|
159 | 162 | Returns: |
@@ -183,28 +186,28 @@ def __repr__(self) -> str: |
183 | 186 | @property |
184 | 187 | def meta(self) -> dict: |
185 | 188 | """Get the meta.""" |
186 | | - return self._meta if hasattr(self, "_meta") else self.get_default_meta() |
| 189 | + return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta() |
187 | 190 |
|
188 | 191 | @meta.setter |
189 | 192 | def meta(self, d) -> None: |
190 | 193 | """Set the meta.""" |
191 | 194 | if d == TraceKeys.NONE: |
192 | | - self._meta = self.get_default_meta() |
| 195 | + self._meta = MetaObj.get_default_meta() |
193 | 196 | self._meta = d |
194 | 197 |
|
195 | 198 | @property |
196 | 199 | def applied_operations(self) -> list: |
197 | 200 | """Get the applied operations.""" |
198 | 201 | if hasattr(self, "_applied_operations"): |
199 | 202 | return self._applied_operations |
200 | | - return self.get_default_applied_operations() |
| 203 | + return MetaObj.get_default_applied_operations() |
201 | 204 |
|
202 | 205 | @applied_operations.setter |
203 | 206 | def applied_operations(self, t) -> None: |
204 | 207 | """Set the applied operations.""" |
205 | 208 | if t == TraceKeys.NONE: |
206 | 209 | # received no operations when decollating a batch |
207 | | - self._applied_operations = self.get_default_applied_operations() |
| 210 | + self._applied_operations = MetaObj.get_default_applied_operations() |
208 | 211 | return |
209 | 212 | self._applied_operations = t |
210 | 213 |
|
|
0 commit comments