 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstraping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then
@@ -155,9 +158,11 @@
     fi

     printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
+    printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
     $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function\\n"
+    printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
     python -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
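Note (not part of the patch): the added lines dump the cluster layout from /opt/ml/input/config/resourceconfig.json and source /opt/ml/input/sm_training.env, which is where the SM_* variables used by the torchrun entrypoint in the next hunk are expected to come from. As a hedged sketch of the kind of values involved, assuming the env file is produced by the bootstrap step (only the resourceconfig.json field names are standard SageMaker training inputs, the derivation here is an approximation):

    # Hedged illustration only: derive torchrun-style settings from the standard
    # SageMaker resource config. SM_MASTER_PORT and SM_NPROC_PER_NODE are assumed
    # to be chosen elsewhere by the bootstrap script.
    import json

    with open("/opt/ml/input/config/resourceconfig.json") as f:
        resource_config = json.load(f)

    hosts = sorted(resource_config["hosts"])        # e.g. ["algo-1", "algo-2"]
    current_host = resource_config["current_host"]  # e.g. "algo-1"

    print(f"export SM_HOST_COUNT={len(hosts)}")
    print(f"export SM_MASTER_ADDR={hosts[0]}")
    print(f"export SM_CURRENT_HOST_RANK={hosts.index(current_host)}")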
@@ -175,9 +180,12 @@
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstraping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then
@@ -200,10 +208,17 @@
     fi

     printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
-    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+    printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+-m sagemaker.remote_function.invoke_function \\n"
+    $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
 -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
+    printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+    torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
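For orientation, once the placeholders resolve on a hypothetical two-node job with eight processes per node, the conda-free branch above boils down to a single torchrun invocation. The snippet below only assembles that command string from illustrative values (the port and host names are made up, not taken from this diff); at runtime the shell substitutes the real values exported via sm_training.env.

    # Illustrative values only; real ones come from the sourced environment file.
    sm_env = {
        "SM_HOST_COUNT": "2",
        "SM_NPROC_PER_NODE": "8",
        "SM_MASTER_ADDR": "algo-1",
        "SM_MASTER_PORT": "7777",  # assumed port, not taken from this diff
        "SM_CURRENT_HOST_RANK": "0",
    }

    command = (
        "torchrun --nnodes {SM_HOST_COUNT} --nproc_per_node {SM_NPROC_PER_NODE} "
        "--master_addr {SM_MASTER_ADDR} --master_port {SM_MASTER_PORT} "
        "--node_rank {SM_CURRENT_HOST_RANK} -m sagemaker.remote_function.invoke_function"
    ).format(**sm_env)
    print(command)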
@@ -262,8 +277,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
-        nproc_per_node=1,
+        use_torchrun: bool = False,
+        nproc_per_node: Optional[int] = None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -445,6 +460,13 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (Optional[int]): Specifies the number of processes per node for
+                distributed training. Defaults to ``None``; if not set, it is determined
+                automatically based on the instance type.
         """
         self.sagemaker_session = sagemaker_session or Session()
         self.environment_variables = resolve_value_from_config(
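Since this hunk only shows the _JobSettings plumbing, here is a hedged usage sketch assuming the public @remote decorator forwards use_torchrun and nproc_per_node to _JobSettings the same way it forwards the existing settings; the instance type and count are placeholders, not values from this diff.

    from sagemaker.remote_function import remote

    @remote(
        instance_type="ml.g5.12xlarge",  # placeholder instance type
        instance_count=2,                # multi-node job
        use_torchrun=True,               # launch the function under torchrun
        nproc_per_node=None,             # None: resolved from the instance at runtime
    )
    def train():
        import torch.distributed as dist

        dist.init_process_group(backend="nccl")
        # ... training loop elided ...
        dist.destroy_process_group()

    train()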
@@ -732,6 +754,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None):
         )

         logger.info("Creating job: %s", job_name)
+
         job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

         return _Job(
@@ -776,8 +799,6 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:
@@ -790,8 +811,6 @@ def compile(
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )

             stored_function.save_pipeline_step_function(serialized_data)
@@ -931,6 +950,7 @@ def compile(
         request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_torchrun_to_request(extended_request, job_settings)

         return extended_request

@@ -1011,7 +1031,7 @@ def _prepare_and_upload_runtime_scripts(
     s3_kms_key: str,
     sagemaker_session: Session,
     use_torchrun: bool = False,
-    nproc_per_node: int = 1,
+    nproc_per_node: Optional[int] = None,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -1030,7 +1050,7 @@ def _prepare_and_upload_runtime_scripts(

         use_torchrun (bool): Whether to use torchrun or not.

-        nproc_per_node (int): Number of processes per node.
+        nproc_per_node (Optional[int]): Number of processes per node. Defaults to ``None``.
     """

     from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1054,7 +1074,11 @@ def _prepare_and_upload_runtime_scripts(

         if use_torchrun:
             entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
-            entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
+            if nproc_per_node is not None and nproc_per_node > 0:
+                entry_point_script = entry_point_script.replace(
+                    "$SM_NPROC_PER_NODE", str(nproc_per_node)
+                )

         with open(entrypoint_script_path, "w", newline="\n") as file:
             file.writelines(entry_point_script)
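To make the substitution rule above concrete: an explicit, positive nproc_per_node is baked into the script when it is rendered, otherwise the $SM_NPROC_PER_NODE placeholder is left in place for the shell to expand from the sourced environment. A self-contained sketch (the template string is a trimmed stand-in, not the real ENTRYPOINT_TORCHRUN_SCRIPT):

    from typing import Optional

    # Trimmed stand-in for ENTRYPOINT_TORCHRUN_SCRIPT, for illustration only.
    TEMPLATE = "torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE ..."

    def render(nproc_per_node: Optional[int]) -> str:
        script = TEMPLATE
        if nproc_per_node is not None and nproc_per_node > 0:
            # Build-time override: bake the user-provided value into the script.
            script = script.replace("$SM_NPROC_PER_NODE", str(nproc_per_node))
        # Otherwise the placeholder survives and is expanded by bash at runtime.
        return script

    assert "$SM_NPROC_PER_NODE" in render(None)
    assert "--nproc_per_node 4" in render(4)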
@@ -1435,6 +1459,35 @@ def _upload_serialized_spark_configuration(
     return config_file_s3_uri


+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
 def _extend_spark_config_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,
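A quick sanity check of the new helper's effect, using a pared-down request dict and a stand-in for _JobSettings (the helper only reads use_torchrun and instance_count, so a simple namespace suffices; the import path assumes this diff's module, sagemaker/remote_function/job.py):

    from types import SimpleNamespace

    from sagemaker.remote_function.job import _extend_torchrun_to_request

    request = {
        "InputDataConfig": [
            {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/prefix"}}}
        ]
    }
    settings = SimpleNamespace(use_torchrun=True, instance_count=2)

    extended = _extend_torchrun_to_request(request, settings)
    s3_source = extended["InputDataConfig"][0]["DataSource"]["S3DataSource"]
    # Multi-node torchrun jobs need every host to see the full dataset,
    # so the channel distribution is switched to FullyReplicated.
    assert s3_source["S3DataDistributionType"] == "FullyReplicated"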