Skip to content

Commit c51d85e

Browse files
4060 Add PatchIter and PatchIterd transform (#4061)
* [DLMED] change PatchIter to be a transform Signed-off-by: Nic Ma <[email protected]> * [DLMED] add dict transform Signed-off-by: Nic Ma <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [DLMED] add unit tests Signed-off-by: Nic Ma <[email protected]> * [DLMED] store coords in dict Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> * [DLMED] restore the doc-string Signed-off-by: Nic Ma <[email protected]> * [DLMED] store more info Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c5bf120 commit c51d85e

File tree

5 files changed

+129
-22
lines changed

5 files changed

+129
-22
lines changed

docs/source/data.rst

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,23 @@ Patch-based dataset
107107
.. autoclass:: GridPatchDataset
108108
:members:
109109

110-
`PatchIter`
111-
~~~~~~~~~~~
112-
.. autoclass:: PatchIter
113-
:members:
114-
115110
`PatchDataset`
116111
~~~~~~~~~~~~~~
117112
.. autoclass:: PatchDataset
118113
:members:
119114

115+
`PatchIter`
116+
"""""""""""
117+
.. autoclass:: PatchIter
118+
:members:
119+
:special-members: __call__
120+
121+
`PatchIterd`
122+
""""""""""""
123+
.. autoclass:: PatchIterd
124+
:members:
125+
:special-members: __call__
126+
120127
Image reader
121128
------------
122129

monai/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
load_decathlon_properties,
3333
)
3434
from .folder_layout import FolderLayout
35-
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
35+
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd
3636
from .image_dataset import ImageDataset
3737
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
3838
from .image_writer import (

monai/data/grid_dataset.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,26 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, Dict, Iterable, Optional, Sequence, Union
12+
from copy import deepcopy
13+
from typing import Callable, Dict, Hashable, Iterable, Mapping, Optional, Sequence, Union
1314

15+
import numpy as np
16+
17+
from monai.config import KeysCollection
1418
from monai.data.dataset import Dataset
1519
from monai.data.iterable_dataset import IterableDataset
1620
from monai.data.utils import iter_patch
1721
from monai.transforms import apply_transform
18-
from monai.utils import NumpyPadMode, deprecated_arg, ensure_tuple, look_up_option
22+
from monai.utils import NumpyPadMode, deprecated_arg, ensure_tuple, first, look_up_option
1923

20-
__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"]
24+
__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]
2125

2226

2327
class PatchIter:
2428
"""
25-
A class to return a patch generator with predefined properties such as `patch_size`.
29+
Return a patch generator with predefined properties such as `patch_size`.
2630
Typically used with :py:class:`monai.data.GridPatchDataset`.
31+
2732
"""
2833

2934
def __init__(
@@ -42,7 +47,8 @@ def __init__(
4247
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
4348
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
4449
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
45-
pad_opts: padding options, see numpy.pad
50+
pad_opts: other arguments for the `np.pad` function.
51+
Note that `np.pad` treats the channel dimension as the first dimension.
4652
4753
Note:
4854
The `patch_size` is the size of the
@@ -52,37 +58,89 @@ def __init__(
5258
specified by a `patch_size` of (10, 10, 10).
5359
5460
"""
55-
self.patch_size = (None,) + tuple(patch_size)
61+
self.patch_size = (None,) + tuple(patch_size) # expand to have the channel dim
5662
self.start_pos = ensure_tuple(start_pos)
5763
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
5864
self.pad_opts = pad_opts
5965

60-
def __call__(self, array):
66+
def __call__(self, array: np.ndarray):
6167
"""
6268
Args:
6369
array: the image to generate patches from.
70+
6471
"""
6572
yield from iter_patch(
6673
array,
67-
patch_size=self.patch_size, # expand to have the channel dim
74+
patch_size=self.patch_size, # type: ignore
6875
start_pos=self.start_pos,
6976
copy_back=False,
7077
mode=self.mode,
7178
**self.pad_opts,
7279
)
7380

7481

82+
class PatchIterd:
83+
"""
84+
Dictionary-based wrapper of :py:class:`monai.data.PatchIter`.
85+
Return a patch generator for dictionary data and the coordinate. Typically used
86+
with :py:class:`monai.data.GridPatchDataset`.
87+
Assumes that all the expected fields specified by `keys` have the same shape.
88+
89+
Args:
90+
keys: keys of the corresponding items to iterate patches.
91+
patch_size: size of patches to generate slices for, 0/None selects whole dimension
92+
start_pos: starting position in the array, default is 0 for each dimension
93+
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
94+
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
95+
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
96+
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
97+
pad_opts: other arguments for the `np.pad` function.
98+
Note that `np.pad` treats the channel dimension as the first dimension.
99+
100+
"""
101+
102+
coords_key = "patch_coords"
103+
original_spatial_shape_key = "original_spatial_shape"
104+
start_pos_key = "start_pos"
105+
106+
def __init__(
107+
self,
108+
keys: KeysCollection,
109+
patch_size: Sequence[int],
110+
start_pos: Sequence[int] = (),
111+
mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
112+
**pad_opts,
113+
):
114+
self.keys = ensure_tuple(keys)
115+
self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts)
116+
117+
def __call__(self, data: Mapping[Hashable, np.ndarray]):
118+
d = dict(data)
119+
original_spatial_shape = d[first(self.keys)].shape[1:]
120+
121+
for patch in zip(*[self.patch_iter(d[key]) for key in self.keys]):
122+
coords = patch[0][1] # use the coordinate of the first item
123+
ret = {k: v[0] for k, v in zip(self.keys, patch)}
124+
# fill in the extra keys with unmodified data
125+
for k in set(d.keys()).difference(set(self.keys)):
126+
ret[k] = deepcopy(d[k])
127+
# also store the `coordinate`, `spatial shape of original image`, `start position` in the dictionary
128+
ret[self.coords_key] = coords
129+
ret[self.original_spatial_shape_key] = original_spatial_shape
130+
ret[self.start_pos_key] = self.patch_iter.start_pos
131+
yield ret, coords
132+
133+
75134
class GridPatchDataset(IterableDataset):
76135
"""
77-
Yields patches from images read from an image dataset.
78-
Typically used with `PatchIter` so that the patches are chosen in a contiguous grid sampling scheme.
136+
Yields patches from data read from an image dataset.
137+
Typically used with `PatchIter` or `PatchIterd` so that the patches are chosen in a contiguous grid sampling scheme.
79138
80139
.. code-block:: python
81140
82141
import numpy as np
83142
84-
from monai.data import GridPatchDataset, DataLoader, PatchIter
85-
from monai.transforms import RandShiftIntensity
143+
from monai.data import GridPatchDataset, DataLoader, PatchIter
from monai.transforms import RandShiftIntensity
86144
87145
# image-level dataset
88146
images = [np.arange(16, dtype=float).reshape(1, 4, 4),
@@ -109,7 +167,7 @@ class GridPatchDataset(IterableDataset):
109167
data: the data source to read image data from.
110168
patch_iter: converts an input image (item from dataset) into an iterable of image patches.
111169
`patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).
112-
see also: :py:class:`monai.data.PatchIter`.
170+
see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
113171
transform: a callable data transform that operates on the patches.
114172
with_coordinates: whether to yield the coordinates of each patch, default to `True`.
115173

monai/transforms/croppad/dictionary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"RandCropByPosNegLabeld",
7171
"ResizeWithPadOrCropd",
7272
"BoundingRectd",
73+
"RandCropByLabelClassesd",
7374
"SpatialPadD",
7475
"SpatialPadDict",
7576
"BorderPadD",
@@ -98,7 +99,6 @@
9899
"ResizeWithPadOrCropDict",
99100
"BoundingRectD",
100101
"BoundingRectDict",
101-
"RandCropByLabelClassesd",
102102
"RandCropByLabelClassesD",
103103
"RandCropByLabelClassesDict",
104104
]

tests/test_grid_dataset.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
import numpy as np
1616

17-
from monai.data import DataLoader, GridPatchDataset, PatchIter
18-
from monai.transforms import RandShiftIntensity
17+
from monai.data import DataLoader, GridPatchDataset, PatchIter, PatchIterd
18+
from monai.transforms import RandShiftIntensity, RandShiftIntensityd
1919
from monai.utils import set_determinism
2020

2121

@@ -76,6 +76,48 @@ def test_loading_array(self):
7676
item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5
7777
)
7878

79+
def test_loading_dict(self):
80+
set_determinism(seed=1234)
81+
# test sequence input data with dict
82+
data = [
83+
{
84+
"image": np.arange(16, dtype=float).reshape(1, 4, 4),
85+
"label": np.arange(16, dtype=float).reshape(1, 4, 4),
86+
"metadata": "test string",
87+
},
88+
{
89+
"image": np.arange(16, dtype=float).reshape(1, 4, 4),
90+
"label": np.arange(16, dtype=float).reshape(1, 4, 4),
91+
"metadata": "test string",
92+
},
93+
]
94+
# image level
95+
patch_intensity = RandShiftIntensityd(keys="image", offsets=1.0, prob=1.0)
96+
patch_iter = PatchIterd(keys=["image", "label"], patch_size=(2, 2), start_pos=(0, 0))
97+
ds = GridPatchDataset(data=data, patch_iter=patch_iter, transform=patch_intensity, with_coordinates=True)
98+
# use the grid patch dataset
99+
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):
100+
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
101+
np.testing.assert_equal(item[0]["label"].shape, (2, 1, 2, 2))
102+
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
103+
np.testing.assert_allclose(
104+
item[0]["image"],
105+
np.array([[[[1.4965, 2.4965], [5.4965, 6.4965]]], [[[11.3584, 12.3584], [15.3584, 16.3584]]]]),
106+
rtol=1e-4,
107+
)
108+
np.testing.assert_allclose(item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
109+
if sys.platform != "win32":
110+
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2):
111+
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
112+
np.testing.assert_allclose(
113+
item[0]["image"],
114+
np.array([[[[1.2548, 2.2548], [5.2548, 6.2548]]], [[[9.1106, 10.1106], [13.1106, 14.1106]]]]),
115+
rtol=1e-3,
116+
)
117+
np.testing.assert_allclose(
118+
item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5
119+
)
120+
79121

80122
if __name__ == "__main__":
81123
unittest.main()

0 commit comments

Comments
 (0)