@@ -26,6 +26,7 @@ def __init__(
2626 test_size = 0.0 ,
2727 val_size = 0.0 ,
2828 compile = None ,
29+ repeat = False ,
2930 automatic_batching = None ,
3031 num_workers = None ,
3132 pin_memory = None ,
@@ -49,9 +50,13 @@ def __init__(
4950 validation dataset. Default is ``0.0``.
5051 :param bool compile: If ``True``, the model is compiled before training.
5152 Default is ``False``. For Windows users, it is always disabled.
53+ :param bool repeat: Whether to repeat the dataset data in each
54+ condition during training. For further details, see the
55+ :class:`~pina.data.PinaDataModule` class. Default is ``False``.
5256 :param bool automatic_batching: If ``True``, automatic PyTorch batching
53- is performed. Avoid using automatic batching when ``batch_size`` is
54- large. Default is ``False``.
57+ is performed, otherwise the items are retrieved from the dataset
58+ all at once. For further details, see the
59+ :class:`~pina.data.PinaDataModule` class. Default is ``False``.
5560 :param int num_workers: The number of worker threads for data loading.
5661 Default is ``0`` (serial loading).
5762 :param bool pin_memory: Whether to use pinned memory for faster data
@@ -65,12 +70,13 @@ def __init__(
6570 """
6671 # check consistency for init types
6772 self ._check_input_consistency (
68- solver ,
69- train_size ,
70- test_size ,
71- val_size ,
72- automatic_batching ,
73- compile ,
73+ solver = solver ,
74+ train_size = train_size ,
75+ test_size = test_size ,
76+ val_size = val_size ,
77+ repeat = repeat ,
78+ automatic_batching = automatic_batching ,
79+ compile = compile ,
7480 )
7581 pin_memory , num_workers , shuffle , batch_size = (
7682 self ._check_consistency_and_set_defaults (
@@ -110,14 +116,15 @@ def __init__(
110116 self ._move_to_device ()
111117 self .data_module = None
112118 self ._create_datamodule (
113- train_size ,
114- test_size ,
115- val_size ,
116- batch_size ,
117- automatic_batching ,
118- pin_memory ,
119- num_workers ,
120- shuffle ,
119+ train_size = train_size ,
120+ test_size = test_size ,
121+ val_size = val_size ,
122+ batch_size = batch_size ,
123+ repeat = repeat ,
124+ automatic_batching = automatic_batching ,
125+ pin_memory = pin_memory ,
126+ num_workers = num_workers ,
127+ shuffle = shuffle ,
121128 )
122129
123130 # logging
@@ -151,6 +158,7 @@ def _create_datamodule(
151158 test_size ,
152159 val_size ,
153160 batch_size ,
161+ repeat ,
154162 automatic_batching ,
155163 pin_memory ,
156164 num_workers ,
@@ -169,6 +177,8 @@ def _create_datamodule(
169177 :param float val_size: The percentage of elements to include in the
170178 validation dataset.
171179 :param int batch_size: The number of samples per batch to load.
180+ :param bool repeat: Whether to repeat the dataset data in each
181+ condition during training.
172182 :param bool automatic_batching: Whether to perform automatic batching
173183 with PyTorch. If ``True``, automatic PyTorch batching
174184 is performed, which consists of extracting one element at a time
@@ -206,6 +216,7 @@ def _create_datamodule(
206216 test_size = test_size ,
207217 val_size = val_size ,
208218 batch_size = batch_size ,
219+ repeat = repeat ,
209220 automatic_batching = automatic_batching ,
210221 num_workers = num_workers ,
211222 pin_memory = pin_memory ,
@@ -253,7 +264,13 @@ def solver(self, solver):
253264
254265 @staticmethod
255266 def _check_input_consistency (
256- solver , train_size , test_size , val_size , automatic_batching , compile
267+ solver ,
268+ train_size ,
269+ test_size ,
270+ val_size ,
271+ repeat ,
272+ automatic_batching ,
273+ compile ,
257274 ):
258275 """
259276 Verifies the consistency of the parameters for the solver configuration.
@@ -265,6 +282,8 @@ def _check_input_consistency(
265282 test dataset.
266283 :param float val_size: The percentage of elements to include in the
267284 validation dataset.
285+ :param bool repeat: Whether to repeat the dataset data in each
286+ condition during training.
268287 :param bool automatic_batching: Whether to perform automatic batching
269288 with PyTorch.
270289 :param bool compile: If ``True``, the model is compiled before training.
@@ -274,6 +293,7 @@ def _check_input_consistency(
274293 check_consistency (train_size , float )
275294 check_consistency (test_size , float )
276295 check_consistency (val_size , float )
296+ check_consistency (repeat , bool )
277297 if automatic_batching is not None :
278298 check_consistency (automatic_batching , bool )
279299 if compile is not None :
0 commit comments