@@ -24,7 +24,7 @@ def run(cfg):
         Full driver/trainer configuration
     """
     # Process the configuration to set up the driver world
-    distributed, world_size = process_world(**cfg)
+    distributed, world_size, torch_sharing = process_world(**cfg)

     # Launch the training/inference process
     if not distributed:
@@ -39,7 +39,9 @@ def run(cfg):

     # Launch the distributed training process
     torch.multiprocessing.spawn(
-        train_single, args=(cfg, distributed, world_size), nprocs=world_size
+        train_single,
+        args=(cfg, distributed, world_size, torch_sharing),
+        nprocs=world_size,
     )


@@ -58,7 +60,7 @@ def run_single(cfg):
         inference_single(cfg)


-def train_single(rank, cfg, distributed=False, world_size=None):
+def train_single(rank, cfg, distributed=False, world_size=None, torch_sharing=None):
     """Train a model in a single process.

     Parameters
@@ -71,6 +73,8 @@ def train_single(rank, cfg, distributed=False, world_size=None):
         If `True`, distribute the training process
     world_size : int, optional
         Number of devices to use in the distributed training process
+    torch_sharing : str or None, optional
+        File sharing strategy for torch distributed training
     """
     # Training always requires torch
     if not TORCH_AVAILABLE:
@@ -79,6 +83,10 @@ def train_single(rank, cfg, distributed=False, world_size=None):
             "Install with: pip install spine-ml[model]"
         )

+    # Set the torch sharing strategy if needed
+    if distributed and torch_sharing is not None:
+        torch.multiprocessing.set_sharing_strategy(torch_sharing)
+
     # If distributed, setup the process group
     if distributed:
         setup_ddp(rank, world_size)
@@ -95,8 +103,7 @@ def train_single(rank, cfg, distributed=False, world_size=None):


 def inference_single(cfg):
-    """
-    Execute a model in inference mode in a single process
+    """Execute a model in inference mode in a single process.

     Parameters
     ----------
@@ -121,7 +128,7 @@ def inference_single(cfg):

     # Loop over the weights, run the inference loop
     for weight in weights:
-        if weight is not None and not preloaded:
+        if driver.model is not None and weight is not None and not preloaded:
             driver.model.load_weights(weight)
             driver.initialize_log()

@@ -145,6 +152,8 @@ def process_world(base, **kwargs):
         If `True`, distribute the training process
     world_size : int
         Number of devices to use in the distributed training process
+    torch_sharing : str or None
+        File sharing strategy for torch distributed training
     """
     # Set the verbosity of the logger
     verbosity = base.get("verbosity", "info")
@@ -159,7 +168,14 @@ def process_world(base, **kwargs):
         world_size < 2 or distributed
     ), "Cannot run process on multiple GPUs without distributing it."

-    return distributed, world_size
+    # If distributed, check what the file sharing strategy is
+    torch_sharing = base.get("torch_sharing_strategy", None)
+    assert not torch_sharing or torch_sharing in ("file_system", "file_descriptor"), (
+        "torch_sharing_strategy must be one of: "
+        "'file_system', 'file_descriptor', or None"
+    )
+
+    return distributed, world_size, torch_sharing


 def setup_ddp(rank, world_size, backend="nccl"):
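For context, a minimal sketch of how the new option is consumed, assuming only what the diff shows: the "torch_sharing_strategy" key read in process_world and the torch.multiprocessing.set_sharing_strategy call made in train_single. The config values below are illustrative, not part of the change.

# Illustrative standalone sketch; "base" here stands in for the base block of the
# driver configuration that process_world reads.
import torch.multiprocessing as mp

base = {"torch_sharing_strategy": "file_system"}  # or "file_descriptor", or omitted

torch_sharing = base.get("torch_sharing_strategy", None)
assert not torch_sharing or torch_sharing in ("file_system", "file_descriptor")

# Strategies actually supported on the current platform (both on Linux):
print(mp.get_all_sharing_strategies())

# In train_single, the strategy is applied before setup_ddp when running distributed
if torch_sharing is not None:
    mp.set_sharing_strategy(torch_sharing)
print(mp.get_sharing_strategy())

The "file_system" strategy trades file-descriptor usage for files on disk, which is the usual workaround when distributed data loading exhausts the process file-descriptor limit; leaving the key unset keeps PyTorch's platform default.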