 fi
 """
 
+ENTRYPOINT_TORCHRUN_SCRIPT = f"""
+#!/bin/bash
+
+# Entry point for bootstrapping runtime environment and invoking remote function with torchrun
+
+set -eu
+
+PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+
+printf "INFO: Bootstrapping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
+    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+        -m sagemaker.remote_function.invoke_function "$@"
+else
+    printf "INFO: No conda env provided. Invoking remote function with torchrun.\\n"
+    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+fi
+"""
+
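Note: because ENTRYPOINT_TORCHRUN_SCRIPT is an f-string, the module constants and the doubled braces are resolved once at import time, while $NPROC_PER_NODE means nothing to Python and survives as literal text for the per-job substitution done later in _prepare_and_upload_runtime_scripts. A minimal sketch of that rendering, using illustrative constant values rather than the module's real ones:

    # Sketch only: the constant values here are hypothetical stand-ins.
    RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap"
    BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py"

    template = f"""
    python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
    PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
    """

    # {NAME} interpolates at import time; ${{...}} collapses to bash's ${...};
    # $NPROC_PER_NODE stays verbatim in the rendered script.
    assert "${SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}" in template
    assert "$NPROC_PER_NODE" in template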
 SPARK_ENTRYPOINT_SCRIPT = f"""
 #!/bin/bash
 
@@ -216,6 +262,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Initialize a _JobSettings instance which configures the remote job.
 
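For context, a hedged usage sketch of the two new settings, assuming the public @remote decorator forwards its keyword arguments to _JobSettings as the rest of this diff suggests (the instance type and process count are illustrative):

    from sagemaker.remote_function import remote

    # Hypothetical example: run the function under torchrun with 4 local processes.
    @remote(
        instance_type="ml.g5.12xlarge",  # illustrative multi-GPU instance choice
        use_torchrun=True,               # selects the torchrun entry point script
        nproc_per_node=4,                # replaces $NPROC_PER_NODE in that script
    )
    def train():
        ...  # training body executed inside the SageMaker job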
@@ -555,6 +603,9 @@ def __init__(
         tags = format_tags(tags)
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
 
+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node
+
     @staticmethod
     def _get_default_image(session):
         """Return Studio notebook image, if in Studio env. Else, base python.
@@ -725,6 +776,8 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
+                use_torchrun=job_settings.use_torchrun,
+                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:
@@ -737,6 +790,8 @@ def compile(
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
+                use_torchrun=job_settings.use_torchrun,
+                nproc_per_node=job_settings.nproc_per_node,
             )
 
             stored_function.save_pipeline_step_function(serialized_data)
@@ -951,7 +1006,12 @@ def _get_job_name(job_settings, func):
 
 
 def _prepare_and_upload_runtime_scripts(
-    spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
+    spark_config: SparkConfig,
+    s3_base_uri: str,
+    s3_kms_key: str,
+    sagemaker_session: Session,
+    use_torchrun: bool = False,
+    nproc_per_node: int = 1,
 ):
     """Copy runtime scripts to a folder and upload to S3.
 
@@ -967,6 +1027,10 @@ def _prepare_and_upload_runtime_scripts(
         s3_kms_key (str): kms key used to encrypt the files uploaded to S3.
 
         sagemaker_session (str): SageMaker boto client session.
+
+        use_torchrun (bool): Whether to invoke the remote function with torchrun. Defaults to False.
+
+        nproc_per_node (int): Number of processes per node to pass to torchrun. Defaults to 1.
     """
 
     from sagemaker.workflow.utilities import load_step_compilation_context
@@ -988,6 +1052,10 @@ def _prepare_and_upload_runtime_scripts(
             )
             shutil.copy2(spark_script_path, bootstrap_scripts)
 
+        if use_torchrun:
+            entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
+            entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
         with open(entrypoint_script_path, "w", newline="\n") as file:
             file.writelines(entry_point_script)
@@ -1025,6 +1093,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
+        use_torchrun=job_settings.use_torchrun,
+        nproc_per_node=job_settings.nproc_per_node,
     )
 
     input_data_config = [