1- import torch
1+ import multiprocessing as mp
22import pytorch_lightning as pl
3+ import torch
34from .dataset import HotppDataset , ShuffledDistributedDataset , DEFAULT_PARALLELIZM
45
56
def pop_loader_params(params):
    """Split DataLoader-specific options out of a mixed parameters dict.

    Removes (pops) every recognized DataLoader option from ``params`` in
    place and returns them in a new dict, leaving the remaining entries
    in ``params`` untouched.

    Args:
        params: Mutable dict of mixed configuration options; recognized
            loader keys are removed from it as a side effect.

    Returns:
        Dict containing only the recognized DataLoader options that were
        present in ``params``.
    """
    loader_params = {}
    # NOTE: fixed a bug where "persistent_workers" was listed with a leading
    # space (" persistent_workers") and therefore could never match.
    for key in ["seed", "num_workers", "batch_size", "cache_size", "parallelize",
                "drop_last", "prefetch_factor", "pin_memory", "persistent_workers",
                "multiprocessing_context"]:
        if key in params:
            loader_params[key] = params.pop(key)
    return loader_params
1314
1415
def get_default_loader_params():
    """Return baseline DataLoader options shared by all stage loaders.

    Enables persistent workers, pins memory when CUDA is available, and
    selects the preferred multiprocessing start method among those the
    platform supports ("forkserver" > "spawn" > "fork").
    """
    defaults = {
        "persistent_workers": True,
        "pin_memory": torch.cuda.is_available(),
    }
    available = set(mp.get_all_start_methods())
    # Pick the first supported start method in preference order, if any.
    chosen = next((m for m in ("forkserver", "spawn", "fork") if m in available), None)
    if chosen is not None:
        defaults["multiprocessing_context"] = chosen
    return defaults
27+
28+
1529class HotppSampler (torch .utils .data .DistributedSampler ):
1630 def __init__ (self , dataset ):
1731 # Skip super init.
@@ -106,10 +120,8 @@ def splits(self):
106120 def train_dataloader (self , rank = None , world_size = None ):
107121 rank = self .trainer .global_rank if rank is None else rank
108122 world_size = self .trainer .world_size if world_size is None else world_size
109- loader_params = {"drop_last" : True ,
110- "multiprocessing_context" : "spawn" ,
111- "persistent_workers" : True ,
112- "pin_memory" : torch .cuda .is_available ()}
123+ loader_params = get_default_loader_params ()
124+ loader_params .update ({"drop_last" : True })
113125 loader_params .update (self .train_loader_params )
114126 dataset = ShuffledDistributedDataset (self .train_data , rank = rank , world_size = world_size ,
115127 cache_size = loader_params .pop ("cache_size" , 4096 ),
@@ -127,9 +139,7 @@ def train_dataloader(self, rank=None, world_size=None):
127139 def val_dataloader (self , rank = None , world_size = None ):
128140 rank = self .trainer .global_rank if rank is None else rank
129141 world_size = self .trainer .world_size if world_size is None else world_size
130- loader_params = {"multiprocessing_context" : "spawn" ,
131- "persistent_workers" : True ,
132- "pin_memory" : torch .cuda .is_available ()}
142+ loader_params = get_default_loader_params ()
133143 loader_params .update (self .val_loader_params )
134144 dataset = ShuffledDistributedDataset (self .val_data , rank = rank , world_size = world_size ,
135145 parallelize = loader_params .pop ("parallelize" , DEFAULT_PARALLELIZM )) # Disable shuffle.
@@ -143,9 +153,7 @@ def val_dataloader(self, rank=None, world_size=None):
143153 def test_dataloader (self , rank = None , world_size = None ):
144154 rank = self .trainer .global_rank if rank is None else rank
145155 world_size = self .trainer .world_size if world_size is None else world_size
146- loader_params = {"multiprocessing_context" : "spawn" ,
147- "persistent_workers" : True ,
148- "pin_memory" : torch .cuda .is_available ()}
156+ loader_params = get_default_loader_params ()
149157 loader_params .update (self .test_loader_params )
150158 dataset = ShuffledDistributedDataset (self .test_data , rank = rank , world_size = world_size ,
151159 parallelize = loader_params .pop ("parallelize" , DEFAULT_PARALLELIZM )) # Disable shuffle.
0 commit comments