Skip to content

Commit 245ab94

Browse files
authored
Split On Grid (#2879)
* Implement SplitOnGrid Signed-off-by: Behrooz <[email protected]> * Implement dictionary-based SplitOnGrid Signed-off-by: Behrooz <[email protected]> * Update inits Signed-off-by: Behrooz <[email protected]> * Update docs Signed-off-by: Behrooz <[email protected]> * Change imports Signed-off-by: Behrooz <[email protected]> * Update input logic in SplitOnGrid) Signed-off-by: Behrooz <[email protected]> * Add unittests for SplitOnGrid and SplitOnGridDict Signed-off-by: Behrooz <[email protected]> * Sort import Signed-off-by: Behrooz <[email protected]> * Remove imports Signed-off-by: Behrooz <[email protected]> * Address comments Signed-off-by: Behrooz <[email protected]> * Remove optional Signed-off-by: Behrooz <[email protected]> * Address thread safety issues Signed-off-by: Behrooz <[email protected]>
1 parent 50dcf8d commit 245ab94

File tree

7 files changed

+418
-0
lines changed

7 files changed

+418
-0
lines changed

docs/source/apps.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,11 @@ Clara MMARs
110110
:members:
111111
.. autoclass:: NormalizeHEStainsd
112112
:members:
113+
114+
.. automodule:: monai.apps.pathology.transforms.spatial.array
115+
.. autoclass:: SplitOnGrid
116+
:members:
117+
118+
.. automodule:: monai.apps.pathology.transforms.spatial.dictionary
119+
.. autoclass:: SplitOnGridd
120+
:members:

monai/apps/pathology/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from .spatial.array import SplitOnGrid
13+
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
1214
from .stain.array import ExtractHEStains, NormalizeHEStains
1315
from .stain.dictionary import (
1416
ExtractHEStainsd,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from .array import SplitOnGrid
13+
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Optional, Tuple, Union
13+
14+
import torch
15+
16+
from monai.transforms.transform import Transform
17+
18+
__all__ = ["SplitOnGrid"]
19+
20+
21+
class SplitOnGrid(Transform):
22+
"""
23+
Split the image into patches based on the provided grid shape.
24+
This transform works only with torch.Tensor inputs.
25+
26+
Args:
27+
grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches.
28+
If it's an integer, the value will be repeated for each dimension. Default is 2x2
29+
patch_size: a tuple or an integer that defines the output patch sizes.
30+
If it's an integer, the value will be repeated for each dimension.
31+
The default is (0, 0), where the patch size will be infered from the grid shape.
32+
33+
Note: the shape of the input image is infered based on the first image used.
34+
"""
35+
36+
def __init__(
37+
self,
38+
grid_size: Union[int, Tuple[int, int]] = (2, 2),
39+
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
40+
):
41+
# Grid size
42+
if isinstance(grid_size, int):
43+
self.grid_size = (grid_size, grid_size)
44+
else:
45+
self.grid_size = grid_size
46+
# Patch size
47+
self.patch_size = None
48+
if isinstance(patch_size, int):
49+
self.patch_size = (patch_size, patch_size)
50+
else:
51+
self.patch_size = patch_size
52+
53+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
54+
if self.grid_size == (1, 1) and self.patch_size is None:
55+
return torch.stack([image])
56+
patch_size, steps = self.get_params(image.shape[1:])
57+
patches = (
58+
image.unfold(1, patch_size[0], steps[0])
59+
.unfold(2, patch_size[1], steps[1])
60+
.flatten(1, 2)
61+
.transpose(0, 1)
62+
.contiguous()
63+
)
64+
return patches
65+
66+
def get_params(self, image_size):
67+
if self.patch_size is None:
68+
patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2))
69+
else:
70+
patch_size = self.patch_size
71+
72+
steps = tuple(
73+
(image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i]
74+
for i in range(2)
75+
)
76+
77+
return patch_size, steps
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
13+
14+
import torch
15+
16+
from monai.config import KeysCollection
17+
from monai.transforms.transform import MapTransform
18+
19+
from .array import SplitOnGrid
20+
21+
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
22+
23+
24+
class SplitOnGridd(MapTransform):
25+
"""
26+
Split the image into patches based on the provided grid shape.
27+
This transform works only with torch.Tensor inputs.
28+
29+
Args:
30+
grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches.
31+
If it's an integer, the value will be repeated for each dimension. Default is 2x2
32+
patch_size: a tuple or an integer that defines the output patch sizes.
33+
If it's an integer, the value will be repeated for each dimension.
34+
The default is (0, 0), where the patch size will be infered from the grid shape.
35+
36+
Note: the shape of the input image is infered based on the first image used.
37+
"""
38+
39+
def __init__(
40+
self,
41+
keys: KeysCollection,
42+
grid_size: Union[int, Tuple[int, int]] = (2, 2),
43+
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
44+
allow_missing_keys: bool = False,
45+
):
46+
super().__init__(keys, allow_missing_keys)
47+
self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size)
48+
49+
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
50+
d = dict(data)
51+
for key in self.key_iterator(d):
52+
d[key] = self.splitter(d[key])
53+
return d
54+
55+
56+
SplitOnGridDict = SplitOnGridD = SplitOnGridd

tests/test_split_on_grid.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
from parameterized import parameterized
17+
18+
from monai.apps.pathology.transforms import SplitOnGrid
19+
20+
A11 = torch.randn(3, 2, 2)
21+
A12 = torch.randn(3, 2, 2)
22+
A21 = torch.randn(3, 2, 2)
23+
A22 = torch.randn(3, 2, 2)
24+
25+
A1 = torch.cat([A11, A12], 2)
26+
A2 = torch.cat([A21, A22], 2)
27+
A = torch.cat([A1, A2], 1)
28+
29+
TEST_CASE_0 = [
30+
{"grid_size": (2, 2)},
31+
A,
32+
torch.stack([A11, A12, A21, A22]),
33+
]
34+
35+
TEST_CASE_1 = [
36+
{"grid_size": (2, 1)},
37+
A,
38+
torch.stack([A1, A2]),
39+
]
40+
41+
TEST_CASE_2 = [
42+
{"grid_size": (1, 2)},
43+
A1,
44+
torch.stack([A11, A12]),
45+
]
46+
47+
TEST_CASE_3 = [
48+
{"grid_size": (1, 2)},
49+
A2,
50+
torch.stack([A21, A22]),
51+
]
52+
53+
TEST_CASE_4 = [
54+
{"grid_size": (1, 1), "patch_size": (2, 2)},
55+
A,
56+
torch.stack([A11]),
57+
]
58+
59+
TEST_CASE_5 = [
60+
{"grid_size": 1, "patch_size": 4},
61+
A,
62+
torch.stack([A]),
63+
]
64+
65+
TEST_CASE_6 = [
66+
{"grid_size": 2, "patch_size": 2},
67+
A,
68+
torch.stack([A11, A12, A21, A22]),
69+
]
70+
71+
TEST_CASE_7 = [
72+
{"grid_size": 1},
73+
A,
74+
torch.stack([A]),
75+
]
76+
77+
TEST_CASE_MC_0 = [
78+
{"grid_size": (2, 2)},
79+
[A, A],
80+
[torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])],
81+
]
82+
83+
84+
TEST_CASE_MC_1 = [
85+
{"grid_size": (2, 1)},
86+
[A] * 5,
87+
[torch.stack([A1, A2])] * 5,
88+
]
89+
90+
91+
TEST_CASE_MC_2 = [
92+
{"grid_size": (1, 2)},
93+
[A1, A2],
94+
[torch.stack([A11, A12]), torch.stack([A21, A22])],
95+
]
96+
97+
98+
class TestSplitOnGrid(unittest.TestCase):
99+
@parameterized.expand(
100+
[
101+
TEST_CASE_0,
102+
TEST_CASE_1,
103+
TEST_CASE_2,
104+
TEST_CASE_3,
105+
TEST_CASE_4,
106+
TEST_CASE_5,
107+
TEST_CASE_6,
108+
TEST_CASE_7,
109+
]
110+
)
111+
def test_split_pathce_single_call(self, input_parameters, img, expected):
112+
splitter = SplitOnGrid(**input_parameters)
113+
output = splitter(img)
114+
np.testing.assert_equal(output.numpy(), expected.numpy())
115+
116+
@parameterized.expand(
117+
[
118+
TEST_CASE_MC_0,
119+
TEST_CASE_MC_1,
120+
TEST_CASE_MC_2,
121+
]
122+
)
123+
def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list):
124+
splitter = SplitOnGrid(**input_parameters)
125+
for img, expected in zip(img_list, expected_list):
126+
output = splitter(img)
127+
np.testing.assert_equal(output.numpy(), expected.numpy())
128+
129+
130+
if __name__ == "__main__":
131+
unittest.main()

0 commit comments

Comments
 (0)