Commit 070b513: Minor fix

FilippoOlivo authored and ndem0 committed
1 parent 1b2154d

2 files changed (+17, -27 lines)

pina/data/data_module.py (6 additions, 13 deletions)
@@ -81,16 +81,9 @@ def __init__(
         :param dict max_conditions_lengths: ``dict`` containing the maximum
             number of data points to consider in a single batch for
             each condition.
-        :param bool automatic_batching: Whether to enable automatic batching.
-            If ``True``, automatic PyTorch batching
-            is performed, which consists of extracting one element at a time
-            from the dataset and collating them into a batch. This is useful
-            when the dataset is too large to fit into memory. On the other hand,
-            if ``False``, the items are retrieved from the dataset all at once
-            avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+        :param bool automatic_batching: Whether automatic PyTorch batching is
+            enabled or not. For more information, see the
+            :class:`~pina.data.data_module.PinaDataModule` class.
         :param PinaDataset dataset: The dataset where the data is stored.
         """

@@ -294,9 +287,9 @@ def __init__(
             when the dataset is too large to fit into memory. On the other hand,
             if ``False``, the items are retrieved from the dataset all at once
             avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+            ``__getitem__`` calls to the dataset. This is useful when the
+            dataset fits into memory. Avoid using automatic batching when
+            ``batch_size`` is large. Default is ``False``.
         :param int num_workers: Number of worker threads for data loading.
             Default ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
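The trade-off this docstring describes is standard PyTorch DataLoader behavior, which the shortened docstring above now delegates to ``PinaDataModule``. A minimal sketch of the two modes in plain PyTorch (the toy dataset and batch sizes are illustrative assumptions, not PINA code):

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToyDataset(Dataset):
    # Hypothetical in-memory dataset, for illustration only.
    def __init__(self, n=100):
        self.data = torch.arange(float(n)).reshape(n, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Receives a single int under automatic batching, or a whole
        # list of indices when batches are fetched in one call.
        return self.data[idx]


dataset = ToyDataset()

# Automatic batching (automatic_batching=True): the loader calls
# __getitem__ once per sample, then collates the samples into a batch.
auto_loader = DataLoader(dataset, batch_size=10)

# No automatic batching (automatic_batching=False): batch_size=None
# disables collation, and the BatchSampler yields index lists, so each
# batch costs a single __getitem__ call via tensor fancy indexing.
manual_loader = DataLoader(
    dataset,
    batch_size=None,
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=10, drop_last=False
    ),
)

for a, m in zip(auto_loader, manual_loader):
    assert torch.equal(a, m)  # identical (10, 1) batches either way

Both loaders produce the same batches; the second simply trades per-item collation overhead for one bulk read, which is why the docstring recommends disabling automatic batching when the data fits in memory and ``batch_size`` is large.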

pina/trainer.py (11 additions, 14 deletions)
@@ -26,7 +26,7 @@ def __init__(
         test_size=0.0,
         val_size=0.0,
         compile=None,
-        repeat=False,
+        repeat=None,
         automatic_batching=None,
         num_workers=None,
         pin_memory=None,
@@ -52,11 +52,13 @@ def __init__(
             Default is ``False``. For Windows users, it is always disabled.
         :param bool repeat: Whether to repeat the dataset data in each
             condition during training. For further details, see the
-            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+            ``False``.
         :param bool automatic_batching: If ``True``, automatic PyTorch batching
             is performed, otherwise the items are retrieved from the dataset
             all at once. For further details, see the
-            :class:`~pina.data.PinaDataModule` class. Default is ``False``.
+            :class:`~pina.data.data_module.PinaDataModule` class. Default is
+            ``False``.
         :param int num_workers: The number of worker threads for data loading.
             Default is ``0`` (serial loading).
         :param bool pin_memory: Whether to use pinned memory for faster data
@@ -105,7 +107,9 @@ def __init__(
         if compile is None or sys.platform == "win32":
             compile = False

-        self.automatic_batching = (
+        repeat = repeat if repeat is not None else False
+
+        automatic_batching = (
             automatic_batching if automatic_batching is not None else False
         )
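Changing the signature default from ``repeat=False`` to ``repeat=None`` aligns ``repeat`` with the sentinel pattern already used for ``compile`` and ``automatic_batching``: ``None`` means "not passed by the caller", and the real default is resolved inside ``__init__`` after validation. A minimal sketch of the pattern (a hypothetical class, not PINA's actual Trainer):

class SketchTrainer:
    # Hypothetical stand-in for the None-sentinel pattern used in
    # pina/trainer.py; not the real Trainer class.
    def __init__(self, repeat=None, automatic_batching=None):
        # Validate only what the caller explicitly supplied.
        if repeat is not None and not isinstance(repeat, bool):
            raise TypeError("repeat must be a bool")
        if automatic_batching is not None and not isinstance(
            automatic_batching, bool
        ):
            raise TypeError("automatic_batching must be a bool")

        # Resolve the sentinels to the documented defaults.
        self.repeat = repeat if repeat is not None else False
        self.automatic_batching = (
            automatic_batching if automatic_batching is not None else False
        )


SketchTrainer()              # both resolve to False
SketchTrainer(repeat=True)   # an explicit value is kept

Resolving the sentinel after validation keeps the documented default (``False``) intact while letting the consistency check below distinguish "unset" from an explicit value.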
@@ -180,15 +184,7 @@ def _create_datamodule(
         :param bool repeat: Whether to repeat the dataset data in each
             condition during training.
         :param bool automatic_batching: Whether to perform automatic batching
-            with PyTorch. If ``True``, automatic PyTorch batching
-            is performed, which consists of extracting one element at a time
-            from the dataset and collating them into a batch. This is useful
-            when the dataset is too large to fit into memory. On the other hand,
-            if ``False``, the items are retrieved from the dataset all at once
-            avoind the overhead of collating them into a batch and reducing the
-            __getitem__ calls to the dataset. This is useful when the dataset
-            fits into memory. Avoid using automatic batching when ``batch_size``
-            is large. Default is ``False``.
+            with PyTorch.
         :param bool pin_memory: Whether to use pinned memory for faster data
             transfer to GPU.
         :param int num_workers: The number of worker threads for data loading.
@@ -293,7 +289,8 @@ def _check_input_consistency(
         check_consistency(train_size, float)
         check_consistency(test_size, float)
         check_consistency(val_size, float)
-        check_consistency(repeat, bool)
+        if repeat is not None:
+            check_consistency(repeat, bool)
         if automatic_batching is not None:
             check_consistency(automatic_batching, bool)
         if compile is not None:
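This guard is what makes the new ``repeat=None`` default safe: the old unconditional ``check_consistency(repeat, bool)`` would reject the sentinel before it is resolved to ``False``. A standalone sketch of the behavior (the ``check_consistency`` body here is a simplified assumption, not PINA's actual pina.utils implementation):

def check_consistency(obj, expected_type):
    # Simplified stand-in for pina.utils.check_consistency: reject
    # values that are not instances of the expected type.
    if not isinstance(obj, expected_type):
        raise ValueError(
            f"expected {expected_type.__name__}, got {type(obj).__name__}"
        )


def _check_input_consistency(repeat=None, automatic_batching=None):
    # None means "use the default later", so only explicit values are
    # type-checked, mirroring the guarded checks in the diff above.
    if repeat is not None:
        check_consistency(repeat, bool)
    if automatic_batching is not None:
        check_consistency(automatic_batching, bool)


_check_input_consistency()             # passes: sentinels are skipped
_check_input_consistency(repeat=True)  # passes: explicit bool
# Without the guard, the first call would raise, because
# check_consistency(None, bool) fails the isinstance test.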
