@@ -90,7 +90,8 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
-    use_torchrun=False,
+    use_torchrun: bool = False,
+    use_mpirun: bool = False,
     nproc_per_node: Optional[int] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.
@@ -207,7 +208,8 @@ def remote(
             files are accepted and uploaded to S3.
 
         instance_count (int): The number of instances to use. Defaults to 1.
-            NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
+            NOTE: Remote function supports instance_count > 1 for Spark jobs and for the
+            torchrun and mpirun utilities.
 
         instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
             the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -284,6 +286,9 @@ def remote(
         use_torchrun (bool): Specifies whether to use torchrun for distributed training.
             Defaults to ``False``.
 
+        use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+            Defaults to ``False``.
+
         nproc_per_node (Optional int): Specifies the number of processes per node for
             distributed training. Defaults to ``None``.
             This is configured automatically based on the instance type.
@@ -320,19 +325,21 @@ def _remote(func):
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
+            use_mpirun=use_mpirun,
             nproc_per_node=nproc_per_node,
         )
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
 
             if instance_count > 1 and not (
-                (spark_config is not None and not use_torchrun)
-                or (spark_config is None and use_torchrun)
+                (spark_config is not None and not use_torchrun and not use_mpirun)
+                or (spark_config is None and use_torchrun and not use_mpirun)
+                or (spark_config is None and not use_torchrun and use_mpirun)
             ):
                 raise ValueError(
                     "Remote function do not support training on multi instances "
-                    + "without spark_config or use_torchrun. "
+                    + "without spark_config or use_torchrun or use_mpirun. "
                     + "Please provide instance_count = 1"
                 )
 
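
With this change, a multi-instance remote function is accepted as long as exactly one launcher is selected: `spark_config`, `use_torchrun`, or `use_mpirun`. A minimal usage sketch, assuming an SDK version that includes this change; the instance type, process count, and training body below are illustrative placeholders:

```python
from sagemaker.remote_function import remote


@remote(
    instance_type="ml.p4d.24xlarge",  # placeholder instance type
    instance_count=2,                 # > 1 is allowed here because use_mpirun=True
    use_mpirun=True,                  # launch the function under mpirun
    nproc_per_node=8,                 # optional; derived from the instance type when None
)
def train():
    # distributed training body goes here (placeholder)
    ...


# Calling the decorated function submits it as a SageMaker training job:
# train()
```

Passing `instance_count > 1` with none of the three options, or with more than one of them, still raises the `ValueError` shown in the hunk above.
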
@@ -536,7 +543,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
+        use_torchrun: bool = False,
+        use_mpirun: bool = False,
         nproc_per_node: Optional[int] = None,
     ):
         """Constructor for RemoteExecutor
@@ -650,7 +658,8 @@ def __init__(
                 files are accepted and uploaded to S3.
 
             instance_count (int): The number of instances to use. Defaults to 1.
-                NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
+                NOTE: Remote function supports instance_count > 1 for Spark jobs and for the
+                torchrun and mpirun utilities.
 
             instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
                 the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -730,6 +739,9 @@ def __init__(
             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.
 
+            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+                Defaults to ``False``.
+
             nproc_per_node (Optional int): Specifies the number of processes per node for
                 distributed training. Defaults to ``None``.
                 This is configured automatically based on the instance type.
@@ -740,12 +752,13 @@ def __init__(
             raise ValueError("max_parallel_jobs must be greater than 0.")
 
         if instance_count > 1 and not (
-            (spark_config is not None and not use_torchrun)
-            or (spark_config is None and use_torchrun)
+            (spark_config is not None and not use_torchrun and not use_mpirun)
+            or (spark_config is None and use_torchrun and not use_mpirun)
+            or (spark_config is None and not use_torchrun and use_mpirun)
         ):
             raise ValueError(
                 "Remote function do not support training on multi instances "
-                + "without spark_config or use_torchrun. "
+                + "without spark_config or use_torchrun or use_mpirun. "
                 + "Please provide instance_count = 1"
             )
 
@@ -778,6 +791,7 @@ def __init__(
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
+            use_mpirun=use_mpirun,
             nproc_per_node=nproc_per_node,
         )
 
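
The same rule applies to `RemoteExecutor`, which forwards `use_mpirun` to `_JobSettings` in the hunk above. A minimal sketch, again with placeholder values and assuming an SDK version that includes this change:

```python
from sagemaker.remote_function import RemoteExecutor


def train():
    # distributed training body goes here (placeholder)
    ...


with RemoteExecutor(
    instance_type="ml.p4d.24xlarge",  # placeholder instance type
    instance_count=2,                 # accepted because use_mpirun=True
    use_mpirun=True,
    max_parallel_jobs=1,
) as executor:
    future = executor.submit(train)
    # future.result() blocks until the training job finishes
```
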