
Commit 1b2154d

FilippoOlivo authored and ndem0 committed
Add docstring for repeat in DataModule
1 parent 7f89c4f commit 1b2154d


2 files changed: +51 -21 lines changed


pina/data/data_module.py

Lines changed: 14 additions & 4 deletions
@@ -283,10 +283,20 @@ def __init__(
             Default is ``None``.
         :param bool shuffle: Whether to shuffle the dataset before splitting.
             Default ``True``.
-        :param bool repeat: Whether to repeat the dataset indefinitely.
-            Default ``False``.
-        :param automatic_batching: Whether to enable automatic batching.
-            Default ``False``.
+        :param bool repeat: If ``True``, in case of batch size larger than the
+            number of elements in a specific condition, the elements are
+            repeated until the batch size is reached. If ``False``, the number
+            of elements in the batch is the minimum between the batch size and
+            the number of elements in the condition. Default is ``False``.
+        :param automatic_batching: If ``True``, automatic PyTorch batching
+            is performed, which consists of extracting one element at a time
+            from the dataset and collating them into a batch. This is useful
+            when the dataset is too large to fit into memory. On the other hand,
+            if ``False``, the items are retrieved from the dataset all at once,
+            avoiding the overhead of collating them into a batch and reducing
+            the ``__getitem__`` calls to the dataset. This is useful when the
+            dataset fits into memory. Avoid using automatic batching when
+            ``batch_size`` is large. Default is ``False``.
         :param int num_workers: Number of worker threads for data loading.
             Default ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
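
To make the new ``repeat`` semantics concrete, here is a minimal standalone sketch of the behavior the docstring describes. It is illustrative only, not PINA's internal dataset code; ``make_batch`` and its arguments are hypothetical names.

# Minimal sketch of the ``repeat`` semantics documented above (hypothetical helper).
import itertools

def make_batch(condition_elements, batch_size, repeat=False):
    """Return the elements that end up in one batch for a single condition."""
    n = len(condition_elements)
    if repeat and n < batch_size:
        # Cycle through the condition until the batch size is reached.
        cycled = itertools.cycle(condition_elements)
        return [next(cycled) for _ in range(batch_size)]
    # Otherwise the batch holds min(batch_size, n) elements.
    return list(condition_elements[: min(batch_size, n)])

samples = list(range(3))                     # a condition with only 3 samples
print(make_batch(samples, 8, repeat=True))   # [0, 1, 2, 0, 1, 2, 0, 1]
print(make_batch(samples, 8, repeat=False))  # [0, 1, 2]

Likewise, the difference between automatic and manual batching described for ``automatic_batching`` can be sketched with plain PyTorch primitives. The toy dataset below is an assumption for illustration and does not reproduce PINA's dataloader code.

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler

class ToyDataset(Dataset):
    def __init__(self, n=10):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # With automatic batching ``idx`` is a single int (one call per element);
        # with a batch sampler it is a list of ints, served in a single call.
        return self.data[idx]

dataset = ToyDataset()

# automatic_batching=True style: one __getitem__ call per element, then collation.
auto_loader = DataLoader(dataset, batch_size=4)

# automatic_batching=False style: the sampler yields whole lists of indices,
# so every batch is fetched with a single __getitem__ call and no collation.
manual_loader = DataLoader(
    dataset,
    sampler=BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False),
    batch_size=None,
)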

pina/trainer.py

Lines changed: 37 additions & 17 deletions
@@ -26,6 +26,7 @@ def __init__(
         test_size=0.0,
         val_size=0.0,
         compile=None,
+        repeat=False,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -49,9 +50,13 @@ def __init__(
             validation dataset. Default is ``0.0``.
         :param bool compile: If ``True``, the model is compiled before training.
             Default is ``False``. For Windows users, it is always disabled.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training. For further details, see the
+            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
         :param bool automatic_batching: If ``True``, automatic PyTorch batching
-            is performed. Avoid using automatic batching when ``batch_size`` is
-            large. Default is ``False``.
+            is performed, otherwise the items are retrieved from the dataset
+            all at once. For further details, see the
+            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
         :param int num_workers: The number of worker threads for data loading.
             Default is ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
@@ -65,12 +70,13 @@ def __init__(
         """
         # check consistency for init types
         self._check_input_consistency(
-            solver,
-            train_size,
-            test_size,
-            val_size,
-            automatic_batching,
-            compile,
+            solver=solver,
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            repeat=repeat,
+            automatic_batching=automatic_batching,
+            compile=compile,
         )
         pin_memory, num_workers, shuffle, batch_size = (
             self._check_consistency_and_set_defaults(
@@ -110,14 +116,15 @@ def __init__(
         self._move_to_device()
         self.data_module = None
         self._create_datamodule(
-            train_size,
-            test_size,
-            val_size,
-            batch_size,
-            automatic_batching,
-            pin_memory,
-            num_workers,
-            shuffle,
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            batch_size=batch_size,
+            repeat=repeat,
+            automatic_batching=automatic_batching,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            shuffle=shuffle,
         )

         # logging
@@ -151,6 +158,7 @@ def _create_datamodule(
         test_size,
         val_size,
         batch_size,
+        repeat,
         automatic_batching,
         pin_memory,
         num_workers,
@@ -169,6 +177,8 @@ def _create_datamodule(
         :param float val_size: The percentage of elements to include in the
             validation dataset.
         :param int batch_size: The number of samples per batch to load.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
             with PyTorch. If ``True``, automatic PyTorch batching
             is performed, which consists of extracting one element at a time
@@ -206,6 +216,7 @@ def _create_datamodule(
             test_size=test_size,
             val_size=val_size,
             batch_size=batch_size,
+            repeat=repeat,
            automatic_batching=automatic_batching,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -253,7 +264,13 @@ def solver(self, solver):

     @staticmethod
     def _check_input_consistency(
-        solver, train_size, test_size, val_size, automatic_batching, compile
+        solver,
+        train_size,
+        test_size,
+        val_size,
+        repeat,
+        automatic_batching,
+        compile,
     ):
         """
         Verifies the consistency of the parameters for the solver configuration.
@@ -265,6 +282,8 @@ def _check_input_consistency(
             test dataset.
         :param float val_size: The percentage of elements to include in the
             validation dataset.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
             with PyTorch.
         :param bool compile: If ``True``, the model is compiled before training.
@@ -274,6 +293,7 @@ def _check_input_consistency(
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
+        check_consistency(repeat, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
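
A usage sketch of the forwarded ``repeat`` argument, assuming a PINA solver is already available; the solver construction is omitted and the numeric values below are illustrative rather than part of this commit.

from pina.trainer import Trainer

def build_trainer(solver):
    # ``solver`` is assumed to be an already-built PINA solver
    # (e.g. a PINN wrapping a problem and a model); its setup is omitted here.
    return Trainer(
        solver=solver,
        batch_size=64,
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        repeat=True,               # pad small conditions up to the batch size
        automatic_batching=False,  # fetch each batch with a single __getitem__ call
    )

With ``repeat=True``, a condition holding fewer than 64 samples is cycled until every batch reaches the requested size, matching the ``PinaDataModule`` docstring above; the returned object is then used as usual (e.g. ``trainer.train()``).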
