
# runtime script names
BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"
+MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py"
ENTRYPOINT_SCRIPT_NAME = "job_driver.sh"
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py"
fi
"""

+ENTRYPOINT_MPIRUN_SCRIPT = f"""
+#!/bin/bash
+
+# Entry point for bootstrapping runtime environment and invoking remote function with mpirun
+
+set -eu
+
+PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json
+
+printf "INFO: Bootstrapping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+
+        printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n"
+        printf "INFO: $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+            --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+            -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+            -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+            -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+            $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+            python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"
+        $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+            --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+            -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+            -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+            -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+            $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+            python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"
+
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
+    else
+        printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n"
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+    fi
+else
+    if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+
+        printf "INFO: No conda env provided. Invoking remote function with mpirun\\n"
+        printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+            --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+            -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+            -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+            -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+            $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+            python -m mpi4py -m sagemaker.remote_function.invoke_function \\n"
+
+        mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
+            --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
+            -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
+            -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
+            -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
+            $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
+            python -m mpi4py -m sagemaker.remote_function.invoke_function "$@"
+
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
+    else
+        printf "INFO: This is the instance $SM_CURRENT_HOST.\\n"
+        python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME}
+    fi
+fi
+"""
+
ENTRYPOINT_TORCHRUN_SCRIPT = f"""
#!/bin/bash

    printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
        --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
        -m sagemaker.remote_function.invoke_function \\n"
+
    $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
        --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
        -m sagemaker.remote_function.invoke_function "$@"
else
    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
    printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
        --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
    torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
        --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
fi
@@ -278,6 +374,7 @@ def __init__(
        use_spot_instances=False,
        max_wait_time_in_seconds=None,
        use_torchrun: bool = False,
+        use_mpirun: bool = False,
        nproc_per_node: Optional[int] = None,
    ):
        """Initialize a _JobSettings instance which configures the remote job.
@@ -464,6 +561,9 @@ def __init__(
            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                Defaults to ``False``.

+            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+                Defaults to ``False``.
+
            nproc_per_node (Optional[int]): Specifies the number of processes per node for
                distributed training. Defaults to ``None``.
                If not specified, this is configured automatically based on the instance type.
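
A usage sketch, assuming the new flag is surfaced through the @remote decorator the same way use_torchrun already is (decorator arguments mirror the _JobSettings fields above; the instance type and counts are illustrative):

from sagemaker.remote_function import remote

# Hypothetical usage: run a function under mpirun across two instances.
# The master node launches instance_count x nproc_per_node MPI ranks.
@remote(
    instance_type="ml.p4d.24xlarge",  # example instance type
    instance_count=2,
    use_mpirun=True,
    nproc_per_node=8,
)
def train():
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    print(f"rank {comm.Get_rank()} of {comm.Get_size()}")


train()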
@@ -626,6 +726,7 @@ def __init__(
        self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

        self.use_torchrun = use_torchrun
+        self.use_mpirun = use_mpirun
        self.nproc_per_node = nproc_per_node

    @staticmethod
@@ -874,6 +975,12 @@ def compile(
                ).to_string(),
            ]
        )
+        if job_settings.use_torchrun:
+            container_args.extend(["--distribution", "torchrun"])
+        elif job_settings.use_mpirun:
+            container_args.extend(["--distribution", "mpirun"])
+        if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0:
+            container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)])
        if job_settings.s3_kms_key:
            container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])

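Traced by hand rather than taken from SDK output: a job configured with use_mpirun=True and nproc_per_node=4 therefore appends --distribution mpirun --user_nproc_per_node 4 to container_args for the bootstrap script to consume. Note that the elif makes torchrun and mpirun mutually exclusive (torchrun wins if both flags were somehow set), and that nproc_per_node is now forwarded to the bootstrapper instead of being substituted into the entrypoint template, matching the removal in _prepare_and_upload_runtime_scripts below.
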
@@ -950,6 +1057,7 @@ def compile(
            request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

        extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_mpirun_to_request(extended_request, job_settings)
        extended_request = _extend_torchrun_to_request(extended_request, job_settings)

        return extended_request
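
Each _extend_*_to_request helper returns the request unchanged unless its feature is enabled (visible in _extend_mpirun_to_request below), so chaining the spark, mpirun, and torchrun extensions in sequence is safe regardless of which single distribution mode is active.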
@@ -1031,7 +1139,7 @@ def _prepare_and_upload_runtime_scripts(
    s3_kms_key: str,
    sagemaker_session: Session,
    use_torchrun: bool = False,
-    nproc_per_node: Optional[int] = None,
+    use_mpirun: bool = False,
):
    """Copy runtime scripts to a folder and upload to S3.

@@ -1050,6 +1158,8 @@ def _prepare_and_upload_runtime_scripts(

        use_torchrun (bool): Whether to use torchrun or not.

+        use_mpirun (bool): Whether to use mpirun or not.
+
        nproc_per_node (Optional[int]): Number of processes per node
    """

@@ -1075,23 +1185,25 @@ def _prepare_and_upload_runtime_scripts(
        if use_torchrun:
            entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT

-            if nproc_per_node is not None and nproc_per_node > 0:
-                entry_point_script = entry_point_script.replace(
-                    "$SM_NPROC_PER_NODE", str(nproc_per_node)
-                )
+        if use_mpirun:
+            entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT

        with open(entrypoint_script_path, "w", newline="\n") as file:
            file.writelines(entry_point_script)

        bootstrap_script_path = os.path.join(
            os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME
        )
+        mpi_utils_path = os.path.join(
+            os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME
+        )
        runtime_manager_script_path = os.path.join(
            os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME
        )

        # copy runtime scripts to tmpdir
        shutil.copy2(bootstrap_script_path, bootstrap_scripts)
+        shutil.copy2(mpi_utils_path, bootstrap_scripts)
        shutil.copy2(runtime_manager_script_path, bootstrap_scripts)

        upload_path = S3Uploader.upload(
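
For orientation: the staged folder uploaded here (the call itself is truncated by the diff context) now holds four runtime scripts, so the bootstrap channel ships the new MPI helper alongside the existing files: job_driver.sh (the entrypoint written above), bootstrap_runtime_environment.py, mpi_utils_remote.py, and runtime_environment_manager.py.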
@@ -1118,7 +1230,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
        s3_kms_key=job_settings.s3_kms_key,
        sagemaker_session=job_settings.sagemaker_session,
        use_torchrun=job_settings.use_torchrun,
-        nproc_per_node=job_settings.nproc_per_node,
+        use_mpirun=job_settings.use_mpirun,
    )

    input_data_config = [
@@ -1459,6 +1571,35 @@ def _upload_serialized_spark_configuration(
    return config_file_s3_uri


+def _extend_mpirun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with mpirun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_mpirun = job_settings.use_mpirun
+    instance_count = job_settings.instance_count
+
+    if not use_mpirun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
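A hand-written before/after for a hypothetical request dict (the channel name and URI are invented for illustration):

# Hypothetical input: one S3 channel that would otherwise be sharded.
request = {
    "InputDataConfig": [
        {
            "ChannelName": "workspace",  # invented for this example
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": "s3://my-bucket/job/workspace",
                    "S3DataDistributionType": "ShardedByS3Key",
                }
            },
        }
    ]
}

# With use_mpirun=True and instance_count > 1, every S3 channel is flipped to
# FullyReplicated so each host downloads identical inputs:
extended = _extend_mpirun_to_request(request, job_settings)
assert (
    extended["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"]
    == "FullyReplicated"
)

Worth noting: request_dict.copy() is a shallow copy, so the nested S3DataSource dicts mutated in the loop are shared with the caller's original dict; that is harmless here because compile() only uses the returned value.
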
def _extend_torchrun_to_request(
    request_dict: Dict,
    job_settings: _JobSettings,