Skip to content

Commit b6b2cfb

Browse files
rijobro and wyli authored
SplitDim (#3884)
* SplitDim Signed-off-by: Richard Brown <[email protected]> * fix Signed-off-by: Richard Brown <[email protected]> * fixes Signed-off-by: Richard Brown <[email protected]> * fix update meta Signed-off-by: Richard Brown <[email protected]> * update docs Signed-off-by: Wenqi Li <[email protected]> Co-authored-by: Wenqi Li <[email protected]>
1 parent e8146e9 commit b6b2cfb

File tree

7 files changed

+239
-30
lines changed

7 files changed

+239
-30
lines changed

docs/source/transforms.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,12 @@ Utility
803803
:members:
804804
:special-members: __call__
805805

806+
`SplitDim`
807+
""""""""""
808+
.. autoclass:: SplitDim
809+
:members:
810+
:special-members: __call__
811+
806812
`SplitChannel`
807813
""""""""""""""
808814
.. autoclass:: SplitChannel
@@ -1638,6 +1644,12 @@ Utility (Dict)
16381644
:members:
16391645
:special-members: __call__
16401646

1647+
`SplitDimd`
1648+
"""""""""""
1649+
.. autoclass:: SplitDimd
1650+
:members:
1651+
:special-members: __call__
1652+
16411653
`SplitChanneld`
16421654
"""""""""""""""
16431655
.. autoclass:: SplitChanneld

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@
412412
RepeatChannel,
413413
SimulateDelay,
414414
SplitChannel,
415+
SplitDim,
415416
SqueezeDim,
416417
ToCupy,
417418
ToDevice,
@@ -509,6 +510,9 @@
509510
SplitChanneld,
510511
SplitChannelD,
511512
SplitChannelDict,
513+
SplitDimd,
514+
SplitDimD,
515+
SplitDimDict,
512516
SqueezeDimd,
513517
SqueezeDimD,
514518
SqueezeDimDict,

monai/transforms/utility/array.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
convert_to_cupy,
3838
convert_to_numpy,
3939
convert_to_tensor,
40+
deprecated,
4041
deprecated_arg,
4142
ensure_tuple,
4243
look_up_option,
@@ -62,6 +63,7 @@
6263
"EnsureType",
6364
"RepeatChannel",
6465
"RemoveRepeatedChannel",
66+
"SplitDim",
6567
"SplitChannel",
6668
"CastToType",
6769
"ToTensor",
@@ -281,33 +283,57 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
281283
return img[:: self.repeats, :]
282284

283285

284-
class SplitChannel(Transform):
286+
class SplitDim(Transform):
    """
    Split an image into a list of sub-images along one dimension.

    For an input of size X along `dim`, the result is a list of X images. Typical uses are
    turning a 3D volume into a stack of 2D slices, or separating a multichannel input into
    single-channel images.

    Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy).

    Args:
        dim: dimension on which to split.
        keepdim: if `True`, every output keeps a singleton in the split dimension. If `False`,
            that dimension is squeezed out of every output.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, dim: int = -1, keepdim: bool = True) -> None:
        self.dim = dim
        self.keepdim = keepdim

    def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]:
        """
        Apply the transform to `img`.

        Returns:
            a list with one entry per index along `self.dim`.

        Raises:
            RuntimeError: when `img` has size 1 (or less) along `self.dim`.
        """
        size_along_dim = img.shape[self.dim]
        if size_along_dim <= 1:
            raise RuntimeError("Input image is singleton along dimension to be split.")

        if isinstance(img, torch.Tensor):
            # chunks of length 1 -> one tensor per index along the dimension
            pieces = list(torch.split(img, 1, self.dim))
        else:
            # `size_along_dim` equal sections -> one array per index along the dimension
            pieces = np.split(img, size_along_dim, self.dim)

        if self.keepdim:
            return pieces
        return [piece.squeeze(self.dim) for piece in pieces]
303320

304-
outputs = []
305-
slices = [slice(None)] * len(img.shape)
306-
for i in range(num_classes):
307-
slices[self.channel_dim] = slice(i, i + 1)
308-
outputs.append(img[tuple(slices)])
309321

310-
return outputs
322+
@deprecated(since="0.8", msg_suffix="please use `SplitDim` instead.")
class SplitChannel(SplitDim):
    """
    Deprecated: split Numpy array or PyTorch Tensor data along the channel dimension.

    Splitting lets subsequent transforms be applied to each channel independently.
    This class only forwards `channel_dim` to `SplitDim`; prefer `SplitDim` in new code.

    Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy).

    Args:
        channel_dim: which dimension of input image is the channel, default to 0.

    """

    def __init__(self, channel_dim: int = 0) -> None:
        # `keepdim` is left at the parent's default.
        super().__init__(dim=channel_dim)
311337

312338

313339
class CastToType(Transform):

monai/transforms/utility/dictionary.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
RemoveRepeatedChannel,
5050
RepeatChannel,
5151
SimulateDelay,
52-
SplitChannel,
52+
SplitDim,
5353
SqueezeDim,
5454
ToCupy,
5555
ToDevice,
@@ -61,7 +61,7 @@
6161
)
6262
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
6363
from monai.transforms.utils_pytorch_numpy_unification import concatenate
64-
from monai.utils import convert_to_numpy, deprecated_arg, ensure_tuple, ensure_tuple_rep
64+
from monai.utils import convert_to_numpy, deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
6565
from monai.utils.enums import PostFix, TraceKeys, TransformBackends
6666
from monai.utils.type_conversion import convert_to_dst_type
6767

@@ -150,6 +150,9 @@
150150
"SplitChannelD",
151151
"SplitChannelDict",
152152
"SplitChanneld",
153+
"SplitDimD",
154+
"SplitDimDict",
155+
"SplitDimd",
153156
"SqueezeDimD",
154157
"SqueezeDimDict",
155158
"SqueezeDimd",
@@ -372,19 +375,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
372375
return d
373376

374377

375-
class SplitChanneld(MapTransform):
376-
"""
377-
Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
378-
All the input specified by `keys` should be split into same count of data.
379-
"""
380-
381-
backend = SplitChannel.backend
382-
378+
class SplitDimd(MapTransform):
383379
def __init__(
384380
self,
385381
keys: KeysCollection,
386382
output_postfixes: Optional[Sequence[str]] = None,
387-
channel_dim: int = 0,
383+
dim: int = 0,
384+
keepdim: bool = True,
385+
update_meta: bool = True,
388386
allow_missing_keys: bool = False,
389387
) -> None:
390388
"""
@@ -395,13 +393,17 @@ def __init__(
395393
for example: if the key of input data is `pred` and split 2 classes, the output
396394
data keys will be: pred_(output_postfixes[0]), pred_(output_postfixes[1])
397395
if None, using the index number: `pred_0`, `pred_1`, ... `pred_N`.
398-
channel_dim: which dimension of input image is the channel, default to 0.
396+
dim: which dimension of the input image to split along, default to 0.
397+
keepdim: if `True`, output will have singleton in the split dimension. If `False`, this
398+
dimension will be squeezed.
399+
update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to
400+
reflect the position of each split image
399401
allow_missing_keys: don't raise exception if key is missing.
400-
401402
"""
402403
super().__init__(keys, allow_missing_keys)
403404
self.output_postfixes = output_postfixes
404-
self.splitter = SplitChannel(channel_dim=channel_dim)
405+
self.splitter = SplitDim(dim, keepdim)
406+
self.update_meta = update_meta
405407

406408
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
407409
d = dict(data)
@@ -415,9 +417,44 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
415417
if split_key in d:
416418
raise RuntimeError(f"input data already contains key {split_key}.")
417419
d[split_key] = r
420+
421+
if self.update_meta:
422+
orig_meta = d.get(PostFix.meta(key), None)
423+
if orig_meta is not None:
424+
split_meta_key = PostFix.meta(split_key)
425+
d[split_meta_key] = deepcopy(orig_meta)
426+
dim = self.splitter.dim
427+
if dim > 0: # don't update affine if channel dim
428+
shift = np.eye(len(d[split_meta_key]["affine"])) # type: ignore
429+
shift[dim - 1, -1] = i # type: ignore
430+
d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore
431+
418432
return d
419433

420434

435+
@deprecated(since="0.8", msg_suffix="please use `SplitDimd` instead.")
class SplitChanneld(SplitDimd):
    """
    Deprecated dictionary-based channel splitter, kept as a thin subclass of `SplitDimd`.

    All the input specified by `keys` should be split into same count of data.
    The constructor forwards `channel_dim` as the split dimension and disables
    metadata updating; prefer `SplitDimd` in new code.
    """

    def __init__(
        self,
        keys: KeysCollection,
        output_postfixes: Optional[Sequence[str]] = None,
        channel_dim: int = 0,
        allow_missing_keys: bool = False,
    ) -> None:
        # `update_meta=False` is passed for backwards compatibility.
        super().__init__(
            keys, output_postfixes=output_postfixes, dim=channel_dim, update_meta=False, allow_missing_keys=allow_missing_keys
        )
456+
457+
421458
class CastToTyped(MapTransform):
422459
"""
423460
Dictionary-based wrapper of :py:class:`monai.transforms.CastToType`.
@@ -1637,6 +1674,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
16371674
RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld
16381675
RepeatChannelD = RepeatChannelDict = RepeatChanneld
16391676
SplitChannelD = SplitChannelDict = SplitChanneld
1677+
SplitDimD = SplitDimDict = SplitDimd
16401678
CastToTypeD = CastToTypeDict = CastToTyped
16411679
ToTensorD = ToTensorDict = ToTensord
16421680
EnsureTypeD = EnsureTypeDict = EnsureTyped

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def run_testsuit():
143143
"test_smartcachedataset",
144144
"test_spacing",
145145
"test_spacingd",
146+
"test_splitdimd",
146147
"test_surface_distance",
147148
"test_testtimeaugmentation",
148149
"test_torchvision",

tests/test_splitdim.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
from parameterized import parameterized
16+
17+
from monai.transforms.utility.array import SplitDim
18+
from tests.utils import TEST_NDARRAYS
19+
20+
# One case per (array type, keepdim) pair; every case uses the same 4D shape.
TESTS = [((2, 10, 8, 7), keepdim, p) for p in TEST_NDARRAYS for keepdim in (True, False)]
24+
25+
26+
class TestSplitDim(unittest.TestCase):
    """Checks on the output count, rank, aliasing and error path of `SplitDim`."""

    @parameterized.expand(TESTS)
    def test_correct_shape(self, shape, keepdim, im_type):
        """Split along every dimension and verify the number and rank of the outputs."""
        arr = im_type(np.random.rand(*shape))
        expected_ndim = arr.ndim if keepdim else arr.ndim - 1
        for dim, n_expected in enumerate(arr.shape):
            pieces = SplitDim(dim, keepdim)(arr)
            self.assertIsInstance(pieces, (list, tuple))
            self.assertEqual(len(pieces), n_expected)
            self.assertEqual(pieces[0].ndim, expected_ndim)
            # the outputs must alias the input: an in-place change of the
            # first element has to be visible through the first piece
            arr[0, 0, 0, 0] *= 2
            self.assertEqual(arr.flatten()[0], pieces[0].flatten()[0])

    def test_error(self):
        """Should fail because splitting along singleton dimension"""
        singleton_shape = (2, 1, 8, 7)
        for im_type in TEST_NDARRAYS:
            arr = im_type(np.random.rand(*singleton_shape))
            with self.assertRaises(RuntimeError):
                SplitDim(dim=1)(arr)
47+
48+
49+
if __name__ == "__main__":
50+
unittest.main()

tests/test_splitdimd.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
from copy import deepcopy
14+
15+
import numpy as np
16+
from parameterized import parameterized
17+
18+
from monai.transforms import LoadImaged
19+
from monai.transforms.utility.dictionary import SplitDimd
20+
from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine
21+
22+
# One case per (array type, keepdim, update_meta) combination,
# ordered as (keepdim, array type, update_meta) to match `test_correct`.
TESTS = []
for p in TEST_NDARRAYS:
    TESTS.extend((keepdim, p, update_meta) for keepdim in (True, False) for update_meta in (True, False))
27+
28+
29+
class TestSplitDimd(unittest.TestCase):
    """Checks on the output keys, metadata, rank, aliasing and error path of `SplitDimd`."""

    @classmethod
    def setUpClass(cls):
        """Create a random 4D NIfTI image with a random affine and load it once for all cases."""
        arr = np.random.rand(2, 10, 8, 7)
        affine = make_rand_affine()
        cls.data = LoadImaged("i")({"i": make_nifti_image(arr, affine)})

    @parameterized.expand(TESTS)
    def test_correct(self, keepdim, im_type, update_meta):
        """Split along every dimension and verify keys, world coordinates and output rank."""
        data = deepcopy(self.data)
        arr = data["i"] = im_type(data["i"])
        # with `update_meta`, each split contributes an image key and a metadata key
        keys_per_split = 2 if update_meta else 1
        expected_ndim = arr.ndim if keepdim else arr.ndim - 1

        for dim in range(arr.ndim):
            out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta)(data)
            self.assertIsInstance(out, dict)
            self.assertEqual(len(out.keys()), len(data.keys()) + keys_per_split * arr.shape[dim])

            # if updating meta data, pick some random points and
            # check same world coordinates between input and output
            if update_meta:
                for _ in range(10):
                    idx = [np.random.choice(i) for i in arr.shape]
                    split_im_idx = idx[dim]
                    # same voxel, expressed relative to the split image that contains it
                    split_idx = [0 if axis == dim else value for axis, value in enumerate(idx)]
                    # drop the channel index, then append 1 as the 4th (homogeneous) element
                    world_from_input = data["i_meta_dict"]["affine"] @ (idx[1:] + [1])
                    world_from_split = out[f"i_{split_im_idx}_meta_dict"]["affine"] @ (split_idx[1:] + [1])
                    assert_allclose(world_from_input, world_from_split)

            first_split = out["i_0"]
            self.assertEqual(first_split.ndim, expected_ndim)
            # the outputs must alias the input: an in-place change of the
            # first element has to be visible through the first split
            arr[0, 0, 0, 0] *= 2
            self.assertEqual(arr.flatten()[0], first_split.flatten()[0])

    def test_error(self):
        """Should fail because splitting along singleton dimension"""
        singleton_shape = (2, 1, 8, 7)
        for im_type in TEST_NDARRAYS:
            arr = im_type(np.random.rand(*singleton_shape))
            with self.assertRaises(RuntimeError):
                SplitDimd("i", dim=1)({"i": arr})
75+
76+
77+
if __name__ == "__main__":
78+
unittest.main()

0 commit comments

Comments
 (0)