99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12- from typing import Callable , Dict , Iterable , Optional , Sequence , Union
12+ from copy import deepcopy
13+ from typing import Callable , Dict , Hashable , Iterable , Mapping , Optional , Sequence , Union
1314
15+ import numpy as np
16+
17+ from monai .config import KeysCollection
1418from monai .data .dataset import Dataset
1519from monai .data .iterable_dataset import IterableDataset
1620from monai .data .utils import iter_patch
1721from monai .transforms import apply_transform
18- from monai .utils import NumpyPadMode , deprecated_arg , ensure_tuple , look_up_option
22+ from monai .utils import NumpyPadMode , deprecated_arg , ensure_tuple , first , look_up_option
1923
20- __all__ = ["PatchDataset" , "GridPatchDataset" , "PatchIter" ]
24+ __all__ = ["PatchDataset" , "GridPatchDataset" , "PatchIter" , "PatchIterd" ]
2125
2226
2327class PatchIter :
2428 """
25- A class to return a patch generator with predefined properties such as `patch_size`.
29+ Return a patch generator with predefined properties such as `patch_size`.
2630 Typically used with :py:class:`monai.data.GridPatchDataset`.
31+
2732 """
2833
2934 def __init__ (
@@ -42,7 +47,8 @@ def __init__(
4247 ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
4348 One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
4449 See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
45- pad_opts: padding options, see numpy.pad
50+ pad_opts: other arguments for the `np.pad` function.
51+ note that `np.pad` treats channel dimension as the first dimension.
4652
4753 Note:
4854 The `patch_size` is the size of the
@@ -52,37 +58,89 @@ def __init__(
5258 specified by a `patch_size` of (10, 10, 10).
5359
5460 """
55- self .patch_size = (None ,) + tuple (patch_size )
61+ self .patch_size = (None ,) + tuple (patch_size ) # expand to have the channel dim
5662 self .start_pos = ensure_tuple (start_pos )
5763 self .mode : NumpyPadMode = look_up_option (mode , NumpyPadMode )
5864 self .pad_opts = pad_opts
5965
60- def __call__ (self , array ):
66+ def __call__ (self , array : np . ndarray ):
6167 """
6268 Args:
6369 array: the image to generate patches from.
70+
6471 """
6572 yield from iter_patch (
6673 array ,
67- patch_size = self .patch_size , # expand to have the channel dim
74+ patch_size = self .patch_size , # type: ignore
6875 start_pos = self .start_pos ,
6976 copy_back = False ,
7077 mode = self .mode ,
7178 ** self .pad_opts ,
7279 )
7380
7481
82+ class PatchIterd :
83+ """
84+ Dictionary-based wrapper of :py:class:`monai.data.PatchIter`.
85+ Return a patch generator for dictionary data and the coordinate, Typically used
86+ with :py:class:`monai.data.GridPatchDataset`.
87+ Suppose all the expected fields specified by `keys` have same shape.
88+
89+ Args:
90+ keys: keys of the corresponding items to iterate patches.
91+ patch_size: size of patches to generate slices for, 0/None selects whole dimension
92+ start_pos: starting position in the array, default is 0 for each dimension
93+ mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
94+ ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
95+ One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
96+ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
97+ pad_opts: other arguments for the `np.pad` function.
98+ note that `np.pad` treats channel dimension as the first dimension.
99+
100+ """
101+
102+ coords_key = "patch_coords"
103+ original_spatial_shape_key = "original_spatial_shape"
104+ start_pos_key = "start_pos"
105+
106+ def __init__ (
107+ self ,
108+ keys : KeysCollection ,
109+ patch_size : Sequence [int ],
110+ start_pos : Sequence [int ] = (),
111+ mode : Union [NumpyPadMode , str ] = NumpyPadMode .WRAP ,
112+ ** pad_opts ,
113+ ):
114+ self .keys = ensure_tuple (keys )
115+ self .patch_iter = PatchIter (patch_size = patch_size , start_pos = start_pos , mode = mode , ** pad_opts )
116+
117+ def __call__ (self , data : Mapping [Hashable , np .ndarray ]):
118+ d = dict (data )
119+ original_spatial_shape = d [first (self .keys )].shape [1 :]
120+
121+ for patch in zip (* [self .patch_iter (d [key ]) for key in self .keys ]):
122+ coords = patch [0 ][1 ] # use the coordinate of the first item
123+ ret = {k : v [0 ] for k , v in zip (self .keys , patch )}
124+ # fill in the extra keys with unmodified data
125+ for k in set (d .keys ()).difference (set (self .keys )):
126+ ret [k ] = deepcopy (d [k ])
127+ # also store the `coordinate`, `spatial shape of original image`, `start position` in the dictionary
128+ ret [self .coords_key ] = coords
129+ ret [self .original_spatial_shape_key ] = original_spatial_shape
130+ ret [self .start_pos_key ] = self .patch_iter .start_pos
131+ yield ret , coords
132+
133+
75134class GridPatchDataset (IterableDataset ):
76135 """
77- Yields patches from images read from an image dataset.
78- Typically used with `PatchIter` so that the patches are chosen in a contiguous grid sampling scheme.
136+ Yields patches from data read from an image dataset.
137+ Typically used with `PatchIter` or `PatchIterd` so that the patches are chosen in a contiguous grid sampling scheme.
79138
80139 .. code-block:: python
81140
82141 import numpy as np
83142
84- from monai.data import GridPatchDataset, DataLoader, PatchIter
85- from monai.transforms import RandShiftIntensity
143+ from monai.data import GridPatchDataset, DataLoader, PatchIter, RandShiftIntensity
86144
87145 # image-level dataset
88146 images = [np.arange(16, dtype=float).reshape(1, 4, 4),
@@ -109,7 +167,7 @@ class GridPatchDataset(IterableDataset):
109167 data: the data source to read image data from.
110168 patch_iter: converts an input image (item from dataset) into a iterable of image patches.
111169 `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).
112- see also: :py:class:`monai.data.PatchIter`.
170+ see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd` .
113171 transform: a callable data transform operates on the patches.
114172 with_coordinates: whether to yield the coordinates of each patch, default to `True`.
115173
0 commit comments