Skip to content

Commit f3f0ceb

Browse files
committed
update trainer
1 parent 0d97e6a commit f3f0ceb

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

pina/trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
automatic_batching=None,
2323
num_workers=None,
2424
pin_memory=None,
25+
shuffle=None,
2526
**kwargs,
2627
):
2728
"""
@@ -53,6 +54,8 @@ def __init__(
5354
:type num_workers: int
5455
:param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
5556
:type pin_memory: bool
57+
:param shuffle: Whether to shuffle the data for training. (Default False)
58+
:type pin_memory: bool
5659
5760
:Keyword Arguments:
5861
The additional keyword arguments specify the training setup
@@ -77,6 +80,10 @@ def __init__(
7780
check_consistency(pin_memory, int)
7881
else:
7982
num_workers = 0
83+
if shuffle is not None:
84+
check_consistency(shuffle, bool)
85+
else:
86+
shuffle = False
8087
if train_size + test_size + val_size + predict_size > 1:
8188
raise ValueError(
8289
"train_size, test_size, val_size and predict_size "
@@ -131,6 +138,7 @@ def __init__(
131138
automatic_batching,
132139
pin_memory,
133140
num_workers,
141+
shuffle
134142
)
135143

136144
# logging
@@ -166,6 +174,7 @@ def _create_datamodule(
166174
automatic_batching,
167175
pin_memory,
168176
num_workers,
177+
shuffle
169178
):
170179
"""
171180
This method is used here because is resampling is needed
@@ -196,6 +205,7 @@ def _create_datamodule(
196205
automatic_batching=automatic_batching,
197206
num_workers=num_workers,
198207
pin_memory=pin_memory,
208+
shuffle=shuffle
199209
)
200210

201211
def train(self, **kwargs):

0 commit comments

Comments
 (0)