@@ -23,7 +23,7 @@ def make_dir(dirPath):
2323'''
2424Reads csv and makes both train and validation data loaders from it
2525'''
26- def get_train_val_loaders (loader_dir , data_csv , batch_size = 1 , down_factor = 1 , down_dir = None , train_split = 0.80 ):
26+ def get_train_val_loaders (loader_dir , data_csv , batch_size = 1 , down_factor = 1 , down_dir = None , train_split = 0.80 , num_workers = 0 ):
2727 sw_message ("Creating training and validation torch loaders:" )
2828 make_dir (loader_dir )
2929 images , scores , models , prefixes = get_all_train_data (loader_dir , data_csv , down_factor , down_dir )
@@ -41,7 +41,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
4141 train_data ,
4242 batch_size = batch_size ,
4343 shuffle = True ,
44- num_workers = 8 ,
44+ num_workers = num_workers ,
4545 pin_memory = torch .cuda .is_available ()
4646 )
4747 train_path = loader_dir + 'train'
@@ -51,7 +51,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
5151 val_data ,
5252 batch_size = 1 ,
5353 shuffle = True ,
54- num_workers = 8 ,
54+ num_workers = num_workers ,
5555 pin_memory = torch .cuda .is_available ()
5656 )
5757 val_path = loader_dir + 'validation'
@@ -62,7 +62,7 @@ def get_train_val_loaders(loader_dir, data_csv, batch_size=1, down_factor=1, dow
6262'''
6363Reads csv and makes just train data loaders
6464'''
65- def get_train_loader (loader_dir , data_csv , batch_size = 1 , down_factor = 1 , down_dir = None , train_split = 0.80 ):
65+ def get_train_loader (loader_dir , data_csv , batch_size = 1 , down_factor = 1 , down_dir = None , train_split = 0.80 , num_workers = 0 ):
6666 sw_message ("Creating training torch loader..." )
6767 # Get data
6868 make_dir (loader_dir )
@@ -74,7 +74,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir
7474 train_data ,
7575 batch_size = batch_size ,
7676 shuffle = True ,
77- num_workers = 8 ,
77+ num_workers = num_workers ,
7878 pin_memory = torch .cuda .is_available ()
7979 )
8080 train_path = loader_dir + 'train'
@@ -85,7 +85,7 @@ def get_train_loader(loader_dir, data_csv, batch_size=1, down_factor=1, down_dir
8585'''
8686Makes validation data loader
8787'''
88- def get_validation_loader (loader_dir , val_img_list , val_particles , down_factor = 1 , down_dir = None ):
88+ def get_validation_loader (loader_dir , val_img_list , val_particles , down_factor = 1 , down_dir = None , num_workers = 0 ):
8989 sw_message ("Creating validation torch loader:" )
9090 # Get data
9191 image_paths = []
@@ -113,7 +113,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1
113113 val_data ,
114114 batch_size = 1 ,
115115 shuffle = False ,
116- num_workers = 8 ,
116+ num_workers = num_workers ,
117117 pin_memory = torch .cuda .is_available ()
118118 )
119119 val_path = loader_dir + 'validation'
@@ -124,7 +124,7 @@ def get_validation_loader(loader_dir, val_img_list, val_particles, down_factor=1
124124'''
125125Makes test data loader
126126'''
127- def get_test_loader (loader_dir , test_img_list , down_factor = 1 , down_dir = None ):
127+ def get_test_loader (loader_dir , test_img_list , down_factor = 1 , down_dir = None , num_workers = 0 ):
128128 sw_message ("Creating test torch loader..." )
129129 # get data
130130 image_paths = []
@@ -152,7 +152,7 @@ def get_test_loader(loader_dir, test_img_list, down_factor=1, down_dir=None):
152152 test_data ,
153153 batch_size = 1 ,
154154 shuffle = False ,
155- num_workers = 8 ,
155+ num_workers = num_workers ,
156156 pin_memory = torch .cuda .is_available ()
157157 )
158158 test_path = loader_dir + 'test'
0 commit comments