@@ -90,7 +90,8 @@ def remote(
90
90
spark_config : SparkConfig = None ,
91
91
use_spot_instances = False ,
92
92
max_wait_time_in_seconds = None ,
93
- use_torchrun = False ,
93
+ use_torchrun : bool = False ,
94
+ use_mpirun : bool = False ,
94
95
nproc_per_node : Optional [int ] = None ,
95
96
):
96
97
"""Decorator for running the annotated function as a SageMaker training job.
@@ -207,7 +208,8 @@ def remote(
207
208
files are accepted and uploaded to S3.
208
209
209
210
instance_count (int): The number of instances to use. Defaults to 1.
210
- NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
211
+ NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
212
+ mpirun utilities.
211
213
212
214
instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
213
215
the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -284,6 +286,9 @@ def remote(
284
286
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
285
287
Defaults to ``False``.
286
288
289
+ use_mpirun (bool): Specifies whether to use mpirun for distributed training.
290
+ Defaults to ``False``.
291
+
287
292
nproc_per_node (Optional int): Specifies the number of processes per node for
288
293
distributed training. Defaults to ``None``.
289
294
This is automatically configured based on the instance type.
@@ -320,19 +325,21 @@ def _remote(func):
320
325
use_spot_instances = use_spot_instances ,
321
326
max_wait_time_in_seconds = max_wait_time_in_seconds ,
322
327
use_torchrun = use_torchrun ,
328
+ use_mpirun = use_mpirun ,
323
329
nproc_per_node = nproc_per_node ,
324
330
)
325
331
326
332
@functools .wraps (func )
327
333
def wrapper (* args , ** kwargs ):
328
334
329
335
if instance_count > 1 and not (
330
- (spark_config is not None and not use_torchrun )
331
- or (spark_config is None and use_torchrun )
336
+ (spark_config is not None and not use_torchrun and not use_mpirun )
337
+ or (spark_config is None and use_torchrun and not use_mpirun )
338
+ or (spark_config is None and not use_torchrun and use_mpirun )
332
339
):
333
340
raise ValueError (
334
341
"Remote function do not support training on multi instances "
335
- + "without spark_config or use_torchrun. "
342
+ + "without spark_config or use_torchrun or use_mpirun . "
336
343
+ "Please provide instance_count = 1"
337
344
)
338
345
@@ -536,7 +543,8 @@ def __init__(
536
543
spark_config : SparkConfig = None ,
537
544
use_spot_instances = False ,
538
545
max_wait_time_in_seconds = None ,
539
- use_torchrun = False ,
546
+ use_torchrun : bool = False ,
547
+ use_mpirun : bool = False ,
540
548
nproc_per_node : Optional [int ] = None ,
541
549
):
542
550
"""Constructor for RemoteExecutor
@@ -650,7 +658,8 @@ def __init__(
650
658
files are accepted and uploaded to S3.
651
659
652
660
instance_count (int): The number of instances to use. Defaults to 1.
653
- NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
661
+ NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and
662
+ mpirun utilities.
654
663
655
664
instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
656
665
the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -730,6 +739,9 @@ def __init__(
730
739
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
731
740
Defaults to ``False``.
732
741
742
+ use_mpirun (bool): Specifies whether to use mpirun for distributed training.
743
+ Defaults to ``False``.
744
+
733
745
nproc_per_node (Optional int): Specifies the number of processes per node for
734
746
distributed training. Defaults to ``None``.
735
747
This is automatically configured based on the instance type.
@@ -740,12 +752,13 @@ def __init__(
740
752
raise ValueError ("max_parallel_jobs must be greater than 0." )
741
753
742
754
if instance_count > 1 and not (
743
- (spark_config is not None and not use_torchrun )
744
- or (spark_config is None and use_torchrun )
755
+ (spark_config is not None and not use_torchrun and not use_mpirun )
756
+ or (spark_config is None and use_torchrun and not use_mpirun )
757
+ or (spark_config is None and not use_torchrun and use_mpirun )
745
758
):
746
759
raise ValueError (
747
760
"Remote function do not support training on multi instances "
748
- + "without spark_config or use_torchrun. "
761
+ + "without spark_config or use_torchrun or use_mpirun . "
749
762
+ "Please provide instance_count = 1"
750
763
)
751
764
@@ -778,6 +791,7 @@ def __init__(
778
791
use_spot_instances = use_spot_instances ,
779
792
max_wait_time_in_seconds = max_wait_time_in_seconds ,
780
793
use_torchrun = use_torchrun ,
794
+ use_mpirun = use_mpirun ,
781
795
nproc_per_node = nproc_per_node ,
782
796
)
783
797
0 commit comments