File tree Expand file tree Collapse file tree 1 file changed +10
-0
lines changed
Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -22,6 +22,7 @@ def __init__(
2222 automatic_batching = None ,
2323 num_workers = None ,
2424 pin_memory = None ,
25+ shuffle = None ,
2526 ** kwargs ,
2627 ):
2728 """
@@ -53,6 +54,8 @@ def __init__(
5354 :type num_workers: int
5455 :param pin_memory: Whether to use pinned memory for faster data transfer to GPU. (Default False)
5556 :type pin_memory: bool
57+ :param shuffle: Whether to shuffle the data for training. (Default False)
58+ :type pin_memory: bool
5659
5760 :Keyword Arguments:
5861 The additional keyword arguments specify the training setup
@@ -77,6 +80,10 @@ def __init__(
7780 check_consistency (pin_memory , int )
7881 else :
7982 num_workers = 0
83+ if shuffle is not None :
84+ check_consistency (shuffle , bool )
85+ else :
86+ shuffle = False
8087 if train_size + test_size + val_size + predict_size > 1 :
8188 raise ValueError (
8289 "train_size, test_size, val_size and predict_size "
@@ -131,6 +138,7 @@ def __init__(
131138 automatic_batching ,
132139 pin_memory ,
133140 num_workers ,
141+ shuffle
134142 )
135143
136144 # logging
@@ -166,6 +174,7 @@ def _create_datamodule(
166174 automatic_batching ,
167175 pin_memory ,
168176 num_workers ,
177+ shuffle
169178 ):
170179 """
171180 This method is used here because is resampling is needed
@@ -196,6 +205,7 @@ def _create_datamodule(
196205 automatic_batching = automatic_batching ,
197206 num_workers = num_workers ,
198207 pin_memory = pin_memory ,
208+ shuffle = shuffle
199209 )
200210
201211 def train (self , ** kwargs ):
You can’t perform that action at this time.
0 commit comments