Commit 15a5851

Add option to change the file sharing strategy in DDP
1 parent c5e9b95

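For context: PyTorch's torch.multiprocessing supports two tensor-sharing strategies, 'file_descriptor' (the default on Linux) and 'file_system'. This commit exposes that choice as a new torch_sharing_strategy configuration option so it can be set for distributed (DDP) training runs.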
File tree

3 files changed: +27 -9

bin/larcv_inject_run_number.py
src/spine/driver.py
src/spine/main.py

bin/larcv_inject_run_number.py

Lines changed: 3 additions & 2 deletions
@@ -148,7 +148,7 @@ def main(source, source_list, dest, overwrite, run_number, offset, suffix):
     # Finalize
     io.finalize()

-    # If needed move the output file to where the
+    # If needed move the output file to where the input file is
     if overwrite:
         os.rename(out_path, file_path)

@@ -182,7 +182,8 @@ def main(source, source_list, dest, overwrite, run_number, offset, suffix):
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument(
         "--run-number",
-        help="Run number to assign to every input file. If -1, each file is assigned a unique run number",
+        help="Run number to assign to every input file. If -1, each file is "
+        "assigned a unique run number",
         type=int,
     )
     group.add_argument(
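Note that the rewrapped help string relies on Python's implicit concatenation of adjacent string literals, so the rendered --help text is unchanged; only the source line length changes.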

src/spine/driver.py

Lines changed: 1 addition & 0 deletions
@@ -281,6 +281,7 @@ def initialize_base(
         rank=None,
         log_step=1,
         distributed=False,
+        torch_sharing_strategy=None,
         split_output=False,
         train=None,
         verbosity="info",
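Presumably the keyword is added to initialize_base so that a base configuration block containing the new torch_sharing_strategy key can still be unpacked into the driver without an unexpected-argument error; the value itself is consumed in src/spine/main.py below.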

src/spine/main.py

Lines changed: 23 additions & 7 deletions
@@ -24,7 +24,7 @@ def run(cfg):
         Full driver/trainer configuration
     """
     # Process the configuration to set up the driver world
-    distributed, world_size = process_world(**cfg)
+    distributed, world_size, torch_sharing = process_world(**cfg)

     # Launch the training/inference process
     if not distributed:
@@ -39,7 +39,9 @@ def run(cfg):

     # Launch the distributed training process
     torch.multiprocessing.spawn(
-        train_single, args=(cfg, distributed, world_size), nprocs=world_size
+        train_single,
+        args=(cfg, distributed, world_size, torch_sharing),
+        nprocs=world_size,
     )

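Note that torch.multiprocessing.spawn calls train_single(i, *args) with the process index i as the first positional argument; this is why rank leads the train_single signature below but does not appear in args.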
@@ -58,7 +60,7 @@ def run_single(cfg):
         inference_single(cfg)


-def train_single(rank, cfg, distributed=False, world_size=None):
+def train_single(rank, cfg, distributed=False, world_size=None, torch_sharing=None):
     """Train a model in a single process.

     Parameters
@@ -71,6 +73,8 @@ def train_single(rank, cfg, distributed=False, world_size=None):
         If `True`, distribute the training process
     world_size : int, optional
         Number of devices to use in the distributed training process
+    torch_sharing : str or None, optional
+        File sharing strategy for torch distributed training
     """
     # Training always requires torch
     if not TORCH_AVAILABLE:
@@ -79,6 +83,10 @@ def train_single(rank, cfg, distributed=False, world_size=None):
             "Install with: pip install spine-ml[model]"
         )

+    # Set the torch sharing strategy if needed
+    if distributed and torch_sharing is not None:
+        torch.multiprocessing.set_sharing_strategy(torch_sharing)
+
     # If distributed, setup the process group
     if distributed:
         setup_ddp(rank, world_size)
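For reference, the added call wraps PyTorch's standard sharing-strategy API. A minimal standalone sketch of that API (independent of this code base; the strategy sets shown are the documented platform defaults):

import torch.multiprocessing as mp

# Available strategies: {'file_descriptor', 'file_system'} on Linux,
# {'file_system'} only on macOS.
print(mp.get_all_sharing_strategies())

# Current strategy; the default on Linux is 'file_descriptor'.
print(mp.get_sharing_strategy())

# 'file_system' backs shared tensors with shared-memory files instead of
# passed file descriptors. It avoids "too many open files" errors when many
# DataLoader workers exchange tensors, at the cost of potentially leaked
# /dev/shm files if a worker dies uncleanly.
mp.set_sharing_strategy("file_system")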
@@ -95,8 +103,7 @@ def train_single(rank, cfg, distributed=False, world_size=None):


 def inference_single(cfg):
-    """
-    Execute a model in inference mode in a single process
+    """Execute a model in inference mode in a single process.

     Parameters
     ----------
@@ -121,7 +128,7 @@ def inference_single(cfg):

     # Loop over the weights, run the inference loop
     for weight in weights:
-        if weight is not None and not preloaded:
+        if driver.model is not None and weight is not None and not preloaded:
             driver.model.load_weights(weight)
         driver.initialize_log()

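The added driver.model is not None guard presumably allows inference configurations that build no model to pass through the weight loop without an attribute error, rather than crashing on driver.model.load_weights.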
@@ -145,6 +152,8 @@ def process_world(base, **kwargs):
         If `True`, distribute the training process
     world_size : int
         Number of devices to use in the distributed training process
+    torch_sharing : str or None
+        File sharing strategy for torch distributed training
     """
     # Set the verbosity of the logger
     verbosity = base.get("verbosity", "info")
@@ -159,7 +168,14 @@ def process_world(base, **kwargs):
         world_size < 2 or distributed
     ), "Cannot run process on multiple GPUs without distributing it."

-    return distributed, world_size
+    # If distributed, check what the file sharing strategy is
+    torch_sharing = base.get("torch_sharing_strategy", None)
+    assert not torch_sharing or torch_sharing in ("file_system", "file_descriptor"), (
+        "torch_sharing_strategy must be one of: "
+        "'file_system', 'file_descriptor', or None"
+    )
+
+    return distributed, world_size, torch_sharing


 def setup_ddp(rank, world_size, backend="nccl"):
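Taken together, process_world now reads the new key from the base block of the configuration. A hypothetical configuration sketch that would exercise the new code path; only torch_sharing_strategy and verbosity are confirmed by this diff, while world_size and distributed are assumed base keys inferred from the assertions above:

# Hypothetical SPINE configuration (key placement assumed from this diff).
cfg = {
    "base": {
        "verbosity": "info",
        "world_size": 4,        # assumed key: number of devices
        "distributed": True,    # assumed key: enables DDP
        # New option; 'file_descriptor' or None would also pass the assert.
        "torch_sharing_strategy": "file_system",
    },
    # ... io/model/train blocks elided ...
}

# run(cfg) would then spawn world_size train_single processes, each calling
# torch.multiprocessing.set_sharing_strategy("file_system") before DDP setup.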
