
Commit ef29f0a

Bug fix and add additional tests for Dataset and DataModule (#517)
1 parent 03ef90c commit ef29f0a

File tree

4 files changed: +143 -16 lines


pina/data/data_module.py

Lines changed: 4 additions & 10 deletions
@@ -217,12 +217,11 @@ class PinaSampler:
     parameter and the environment in which the code is running.
     """
 
-    def __new__(cls, dataset, shuffle):
+    def __new__(cls, dataset):
         """
         Instantiate and initialize the sampler.
 
         :param PinaDataset dataset: The dataset from which to sample.
-        :param bool shuffle: Whether to shuffle the dataset.
         :return: The sampler instance.
         :rtype: :class:`torch.utils.data.Sampler`
         """
@@ -231,12 +230,9 @@ def __new__(cls, dataset, shuffle):
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
         ):
-            sampler = DistributedSampler(dataset, shuffle=shuffle)
+            sampler = DistributedSampler(dataset)
         else:
-            if shuffle:
-                sampler = RandomSampler(dataset)
-            else:
-                sampler = SequentialSampler(dataset)
+            sampler = SequentialSampler(dataset)
         return sampler
 
 
@@ -496,8 +492,6 @@ def _create_dataloader(self, split, dataset):
         :return: The dataloader for the given split.
         :rtype: torch.utils.data.DataLoader
         """
-
-        shuffle = self.shuffle if split == "train" else False
         # Suppress the warning about num_workers.
         # In many cases, especially for PINNs,
         # serial data loading can outperform parallel data loading.
@@ -511,7 +505,7 @@ def _create_dataloader(self, split, dataset):
         )
         # Use custom batching (good if batch size is large)
        if self.batch_size is not None:
-            sampler = PinaSampler(dataset, shuffle)
+            sampler = PinaSampler(dataset)
             if self.automatic_batching:
                 collate = Collator(
                     self.find_max_conditions_lengths(split),
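
In short, this change drops the `shuffle` argument from `PinaSampler`: in a distributed run the sampler is now always a `DistributedSampler` with its default shuffling behaviour, and otherwise a plain `SequentialSampler`. A minimal sketch of the resulting selection logic, condensed from the diff above (the real class in `pina/data/data_module.py` is larger):

```python
# Condensed sketch of PinaSampler after this commit; not the full class.
import torch
from torch.utils.data import DistributedSampler, SequentialSampler


class PinaSampler:
    """Choose a sampler based on the execution environment."""

    def __new__(cls, dataset):
        # Shard the dataset across processes when running distributed.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return DistributedSampler(dataset)
        # Otherwise iterate the dataset in order; shuffling is no longer decided here.
        return SequentialSampler(dataset)
```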

pina/data/dataset.py

Lines changed: 19 additions & 3 deletions
@@ -167,9 +167,15 @@ def get_all_data(self):
         :return: A dictionary containing all the data in the dataset.
         :rtype: dict
         """
-
-        index = list(range(len(self)))
-        return self.fetch_from_idx_list(index)
+        to_return_dict = {}
+        for condition, data in self.conditions_dict.items():
+            len_condition = len(
+                data["input"]
+            )  # Length of the current condition
+            to_return_dict[condition] = self._retrive_data(
+                data, list(range(len_condition))
+            )  # Retrieve the data from the current condition
+        return to_return_dict
 
     def fetch_from_idx_list(self, idx):
         """
@@ -306,3 +312,13 @@ def _retrive_data(self, data, idx_list):
             )
             for k, v in data.items()
         }
+
+    @property
+    def input(self):
+        """
+        Return the input data for the dataset.
+
+        :return: Dictionary containing the input points.
+        :rtype: dict
+        """
+        return {k: v["input"] for k, v in self.conditions_dict.items()}

tests/test_data/test_data_module.py

Lines changed: 91 additions & 0 deletions
@@ -238,3 +238,94 @@ def test_dataloader_labels(input_, output_, automatic_batching):
     assert data["data"]["input"].labels == ["u", "v", "w"]
     assert isinstance(data["data"]["target"], torch.Tensor)
     assert data["data"]["target"].labels == ["u", "v", "w"]
+
+
+def test_get_all_data():
+    input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
+    target = input
+
+    problem = SupervisedProblem(input, target)
+    datamodule = PinaDataModule(
+        problem,
+        train_size=0.7,
+        test_size=0.2,
+        val_size=0.1,
+        batch_size=64,
+        shuffle=False,
+        repeat=False,
+        automatic_batching=None,
+        num_workers=0,
+        pin_memory=False,
+    )
+    datamodule.setup("fit")
+    datamodule.setup("test")
+    assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
+    assert torch.isclose(
+        datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
+    ).all()
+    assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
+    assert torch.isclose(
+        datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
+    ).all()
+    assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
+    assert torch.isclose(
+        datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
+    ).all()
+
+
+def test_input_propery_tensor():
+    input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
+    target = input
+
+    problem = SupervisedProblem(input, target)
+    datamodule = PinaDataModule(
+        problem,
+        train_size=0.7,
+        test_size=0.2,
+        val_size=0.1,
+        batch_size=64,
+        shuffle=False,
+        repeat=False,
+        automatic_batching=None,
+        num_workers=0,
+        pin_memory=False,
+    )
+    datamodule.setup("fit")
+    datamodule.setup("test")
+    input_ = datamodule.input
+    assert isinstance(input_, dict)
+    assert isinstance(input_["train"], dict)
+    assert isinstance(input_["val"], dict)
+    assert isinstance(input_["test"], dict)
+    assert torch.isclose(input_["train"]["data"], input[:700]).all()
+    assert torch.isclose(input_["val"]["data"], input[900:]).all()
+    assert torch.isclose(input_["test"]["data"], input[700:900]).all()
+
+
+def test_input_propery_graph():
+    problem = SupervisedProblem(input_graph, output_graph)
+    datamodule = PinaDataModule(
+        problem,
+        train_size=0.7,
+        test_size=0.2,
+        val_size=0.1,
+        batch_size=64,
+        shuffle=False,
+        repeat=False,
+        automatic_batching=None,
+        num_workers=0,
+        pin_memory=False,
+    )
+    datamodule.setup("fit")
+    datamodule.setup("test")
+    input_ = datamodule.input
+    assert isinstance(input_, dict)
+    assert isinstance(input_["train"], dict)
+    assert isinstance(input_["val"], dict)
+    assert isinstance(input_["test"], dict)
+    assert isinstance(input_["train"]["data"], list)
+    assert isinstance(input_["val"]["data"], list)
+    assert isinstance(input_["test"]["data"], list)
+    assert len(input_["train"]["data"]) == 70
+    assert len(input_["val"]["data"]) == 10
+    assert len(input_["test"]["data"]) == 20

tests/test_data/test_graph_dataset.py

Lines changed: 29 additions & 3 deletions
@@ -31,7 +31,7 @@
 max_conditions_lengths_single = {"data": 100}
 
 # Problem with multiple conditions
-conditions_dict_single_multi = {
+conditions_dict_multi = {
     "data_1": {
         "input": input_,
         "target": output_,
@@ -49,7 +49,7 @@
     "conditions_dict, max_conditions_lengths",
     [
         (conditions_dict_single, max_conditions_lengths_single),
-        (conditions_dict_single_multi, max_conditions_lengths_multi),
+        (conditions_dict_multi, max_conditions_lengths_multi),
     ],
 )
 def test_constructor(conditions_dict, max_conditions_lengths):
@@ -66,7 +66,7 @@ def test_constructor(conditions_dict, max_conditions_lengths):
     "conditions_dict, max_conditions_lengths",
     [
         (conditions_dict_single, max_conditions_lengths_single),
-        (conditions_dict_single_multi, max_conditions_lengths_multi),
+        (conditions_dict_multi, max_conditions_lengths_multi),
     ],
 )
 def test_getitem(conditions_dict, max_conditions_lengths):
@@ -110,3 +110,29 @@ def test_getitem(conditions_dict, max_conditions_lengths):
         ]
     )
     assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
+
+
+def test_input_single_condition():
+    dataset = PinaDatasetFactory(
+        conditions_dict_single,
+        max_conditions_lengths=max_conditions_lengths_single,
+        automatic_batching=True,
+    )
+    input_ = dataset.input
+    assert isinstance(input_, dict)
+    assert isinstance(input_["data"], list)
+    assert all([isinstance(d, Data) for d in input_["data"]])
+
+
+def test_input_multi_condition():
+    dataset = PinaDatasetFactory(
+        conditions_dict_multi,
+        max_conditions_lengths=max_conditions_lengths_multi,
+        automatic_batching=True,
+    )
+    input_ = dataset.input
+    assert isinstance(input_, dict)
+    assert isinstance(input_["data_1"], list)
+    assert all([isinstance(d, Data) for d in input_["data_1"]])
+    assert isinstance(input_["data_2"], list)
+    assert all([isinstance(d, Data) for d in input_["data_2"]])
