Skip to content

Commit 28856b8

Browse files
authored
2789 Add ToDevice transform (#2791)
* [DLMED] add ToDevice transform Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix type-hints Signed-off-by: Nic Ma <[email protected]> * [DLMED] inherit Transform Signed-off-by: Nic Ma <[email protected]> * [DLMED] add kwargs Signed-off-by: Nic Ma <[email protected]>
1 parent 46d5f2a commit 28856b8

File tree

6 files changed

+155
-0
lines changed

6 files changed

+155
-0
lines changed

docs/source/transforms.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,12 @@ Utility
718718
:members:
719719
:special-members: __call__
720720

721+
`ToDevice`
722+
""""""""""
723+
.. autoclass:: ToDevice
724+
:members:
725+
:special-members: __call__
726+
721727

722728
Dictionary Transforms
723729
---------------------
@@ -1347,6 +1353,12 @@ Utility (Dict)
13471353
:members:
13481354
:special-members: __call__
13491355

1356+
`ToDeviced`
1357+
"""""""""""
1358+
.. autoclass:: ToDeviced
1359+
:members:
1360+
:special-members: __call__
1361+
13501362

13511363
Transform Adaptors
13521364
------------------

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@
377377
SplitChannel,
378378
SqueezeDim,
379379
ToCupy,
380+
ToDevice,
380381
ToNumpy,
381382
ToPIL,
382383
TorchVision,
@@ -468,6 +469,9 @@
468469
ToCupyd,
469470
ToCupyD,
470471
ToCupyDict,
472+
ToDeviced,
473+
ToDeviceD,
474+
ToDeviceDict,
471475
ToNumpyd,
472476
ToNumpyD,
473477
ToNumpyDict,

monai/transforms/utility/array.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"TorchVision",
7474
"MapLabelValue",
7575
"IntensityStats",
76+
"ToDevice",
7677
]
7778

7879

@@ -1021,3 +1022,28 @@ def _compute(op: Callable, data: np.ndarray):
10211022
raise ValueError("ops must be key string for predefined operations or callable function.")
10221023

10231024
return img, meta_data
1025+
1026+
1027+
class ToDevice(Transform):
1028+
"""
1029+
Move PyTorch Tensor to the specified device.
1030+
It can help cache data into GPU and execute following logic on GPU directly.
1031+
1032+
"""
1033+
1034+
def __init__(self, device: Union[torch.device, str], **kwargs) -> None:
1035+
"""
1036+
Args:
1037+
device: target device to move the Tensor, for example: "cuda:1".
1038+
kwargs: other args for the PyTorch `Tensor.to()` API, for more details:
1039+
https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.
1040+
1041+
"""
1042+
self.device = device
1043+
self.kwargs = kwargs
1044+
1045+
def __call__(self, img: torch.Tensor):
1046+
if not isinstance(img, torch.Tensor):
1047+
raise ValueError("img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.")
1048+
1049+
return img.to(self.device, **self.kwargs)

monai/transforms/utility/dictionary.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SplitChannel,
5050
SqueezeDim,
5151
ToCupy,
52+
ToDevice,
5253
ToNumpy,
5354
ToPIL,
5455
TorchVision,
@@ -141,6 +142,9 @@
141142
"ToCupyD",
142143
"ToCupyDict",
143144
"ToCupyd",
145+
"ToDeviced",
146+
"ToDeviceD",
147+
"ToDeviceDict",
144148
"ToNumpyD",
145149
"ToNumpyDict",
146150
"ToNumpyd",
@@ -1354,6 +1358,37 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]:
13541358
return d
13551359

13561360

1361+
class ToDeviced(MapTransform):
1362+
"""
1363+
Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`.
1364+
"""
1365+
1366+
def __init__(
1367+
self,
1368+
keys: KeysCollection,
1369+
device: Union[torch.device, str],
1370+
allow_missing_keys: bool = False,
1371+
**kwargs,
1372+
) -> None:
1373+
"""
1374+
Args:
1375+
keys: keys of the corresponding items to be transformed.
1376+
See also: :py:class:`monai.transforms.compose.MapTransform`
1377+
device: target device to move the Tensor, for example: "cuda:1".
1378+
allow_missing_keys: don't raise exception if key is missing.
1379+
kwargs: other args for the PyTorch `Tensor.to()` API, for more details:
1380+
https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.
1381+
"""
1382+
super().__init__(keys, allow_missing_keys)
1383+
self.converter = ToDevice(device=device, **kwargs)
1384+
1385+
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
1386+
d = dict(data)
1387+
for key in self.key_iterator(d):
1388+
d[key] = self.converter(d[key])
1389+
return d
1390+
1391+
13571392
IdentityD = IdentityDict = Identityd
13581393
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
13591394
AsChannelLastD = AsChannelLastDict = AsChannelLastd
@@ -1389,3 +1424,4 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]:
13891424
RandLambdaD = RandLambdaDict = RandLambdad
13901425
MapLabelValueD = MapLabelValueDict = MapLabelValued
13911426
IntensityStatsD = IntensityStatsDict = IntensityStatsd
1427+
ToDeviceD = ToDeviceDict = ToDeviced

tests/test_to_device.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
from parameterized import parameterized
16+
17+
from monai.transforms import ToDevice
18+
from tests.utils import skip_if_no_cuda
19+
20+
TEST_CASE_1 = ["cuda:0"]
21+
22+
TEST_CASE_2 = ["cuda"]
23+
24+
TEST_CASE_3 = [torch.device("cpu:0")]
25+
26+
TEST_CASE_4 = ["cpu"]
27+
28+
29+
@skip_if_no_cuda
30+
class TestToDevice(unittest.TestCase):
31+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
32+
def test_value(self, device):
33+
converter = ToDevice(device=device, non_blocking=True)
34+
data = torch.tensor([1, 2, 3, 4])
35+
ret = converter(data)
36+
torch.testing.assert_allclose(ret, data.to(device))
37+
38+
39+
if __name__ == "__main__":
40+
unittest.main()

tests/test_to_deviced.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
16+
from monai.data import CacheDataset, ThreadDataLoader
17+
from monai.transforms import ToDeviced
18+
from tests.utils import skip_if_no_cuda
19+
20+
21+
@skip_if_no_cuda
22+
class TestToDeviced(unittest.TestCase):
23+
def test_value(self):
24+
device = "cuda:0"
25+
data = [{"img": torch.tensor(i)} for i in range(4)]
26+
dataset = CacheDataset(
27+
data=data,
28+
transform=ToDeviced(keys="img", device=device, non_blocking=True),
29+
cache_rate=1.0,
30+
)
31+
dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1)
32+
for i, d in enumerate(dataloader):
33+
torch.testing.assert_allclose(d["img"], torch.tensor([i], device=device))
34+
35+
36+
if __name__ == "__main__":
37+
unittest.main()

0 commit comments

Comments
 (0)