Skip to content

Commit 5b36d02

Browse files
Nic-Ma and monai-bot authored
3444 Add DatasetFunc (#3456)
* [DLMED] add dataset generator
  Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add DatasetGenerator
  Signed-off-by: Nic Ma <[email protected]>
* [DLMED] update according to comments
  Signed-off-by: Nic Ma <[email protected]>
* [MONAI] python code formatting
  Signed-off-by: monai-bot <[email protected]>
* [DLMED] fix wrong test
  Signed-off-by: Nic Ma <[email protected]>
* [DLMED] simplify according to comments
  Signed-off-by: Nic Ma <[email protected]>
* [DLMED] remove return
  Signed-off-by: Nic Ma <[email protected]>
* [DLMED] update rtol for CI
  Signed-off-by: Nic Ma <[email protected]>

Co-authored-by: monai-bot <[email protected]>
1 parent 4b5ad0b commit 5b36d02

File tree

5 files changed

+110
-1
lines changed

5 files changed

+110
-1
lines changed

docs/source/data.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ Generic Interfaces
2121
:members:
2222
:special-members: __next__
2323

24+
`DatasetFunc`
25+
~~~~~~~~~~~~~
26+
.. autoclass:: DatasetFunc
27+
:members:
28+
:special-members: __next__
29+
2430
`ShuffleBuffer`
2531
~~~~~~~~~~~~~~~
2632
.. autoclass:: ShuffleBuffer

monai/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
CacheNTransDataset,
1818
CSVDataset,
1919
Dataset,
20+
DatasetFunc,
2021
LMDBDataset,
2122
NPZDictItemDataset,
2223
PersistentDataset,

monai/data/dataset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,56 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]):
9797
return self._transform(index)
9898

9999

100+
class DatasetFunc(Dataset):
    """
    Execute function on the input dataset and leverage the output to act as a new Dataset.
    It can be used to load / fetch the basic dataset items, like the list of `image, label` paths.
    Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc.
    The `data` arg of `Dataset` will be applied to the first arg of callable `func`.

    `func` is always invoked as ``func(data, **kwargs)``, i.e. the input data is passed
    positionally, so `func` must accept at least one positional argument.
    The function is executed once at construction time and its return value is stored as `self.data`.

    Usage example::

        data_list = DatasetFunc(
            data="path to file",
            func=monai.data.load_decathlon_datalist,
            data_list_key="validation",
            base_dir="path to base dir",
        )
        # partition dataset for every rank
        data_partition = DatasetFunc(
            data=data_list,
            func=lambda x, **kwargs: monai.data.partition_dataset(x, **kwargs)[torch.distributed.get_rank()],
            num_partitions=torch.distributed.get_world_size(),
        )
        dataset = Dataset(data=data_partition, transform=transforms)

    Args:
        data: input data for the func to process, will apply to `func` as the first arg.
        func: callable function to generate dataset items.
        kwargs: other arguments for the `func` except for the first arg.

    """

    def __init__(self, data: Any, func: Callable, **kwargs) -> None:
        # The real items are produced by `func` in `reset()`, so the base class starts empty.
        super().__init__(data=None, transform=None)  # type:ignore
        self.src = data
        self.func = func
        self.kwargs = kwargs
        self.reset()

    def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs) -> None:
        """
        Reset the dataset items with specified `func`.

        The overrides given here are used for this call only: `self.src`, `self.func`
        and `self.kwargs` are left untouched.

        Args:
            data: if not None, execute `func` on it, default to `self.src`.
            func: if not None, execute the `func` with specified `kwargs`, default to `self.func`.
            kwargs: other arguments for the `func` except for the first arg.
                Only used when `func` is provided; otherwise the stored `self.kwargs` are used
                and these are ignored.

        """
        src = self.src if data is None else data
        if func is None:
            # No override: re-run the stored function with the stored keyword arguments.
            self.data = self.func(src, **self.kwargs)
        else:
            self.data = func(src, **kwargs)
148+
149+
100150
class PersistentDataset(Dataset):
101151
"""
102152
Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data,

tests/test_dataset_func.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import json
13+
import os
14+
import tempfile
15+
import unittest
16+
17+
from monai.data import Dataset, DatasetFunc, load_decathlon_datalist, partition_dataset
18+
19+
20+
class TestDatasetFunc(unittest.TestCase):
    """Check that chained `DatasetFunc` objects can feed a regular `Dataset`."""

    def test_seg_values(self):
        """Load a decathlon-style datalist, keep the first of two partitions, verify item 0."""

        def first_partition(items, **kwargs):
            # `DatasetFunc` passes the data positionally, followed by the keyword arguments.
            return partition_dataset(items, **kwargs)[0]

        with tempfile.TemporaryDirectory() as tempdir:
            # prepare test datalist file
            datalist_content = {
                "name": "Spleen",
                "description": "Spleen Segmentation",
                "labels": {"0": "background", "1": "spleen"},
                "training": [
                    {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"},
                    {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"},
                ],
                "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"],
            }
            datalist_path = os.path.join(tempdir, "test_data.json")
            with open(datalist_path, "w") as datalist_file:
                json.dump(datalist_content, datalist_file)

            loaded_items = DatasetFunc(
                data=datalist_path,
                func=load_decathlon_datalist,
                data_list_key="training",
                base_dir=tempdir,
            )
            # partition dataset for train / validation
            partitioned_items = DatasetFunc(data=loaded_items, func=first_partition, num_partitions=2)
            dataset = Dataset(data=partitioned_items, transform=None)

            expected_path = os.path.join(tempdir, "spleen_19.nii.gz")
            first_item = dataset[0]
            self.assertEqual(first_item["image"], expected_path)
            self.assertEqual(first_item["label"], expected_path)
49+
50+
51+
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()

tests/test_scale_intensity_range_percentilesd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_scaling(self):
3535
scaler = ScaleIntensityRangePercentilesd(
3636
keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max
3737
)
38-
assert_allclose(p(expected), scaler(data)["img"])
38+
assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4)
3939

4040
def test_relative_scaling(self):
4141
img = self.imt

0 commit comments

Comments
 (0)