 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
 
+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json
 
 printf "INFO: Bootstrapping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env
 
 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then
     fi
 
     printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
+    printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
+
     $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function\\n"
+    printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
+
     python -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
 
+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json
 
 printf "INFO: Bootstrapping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env
 
 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then
     fi
 
     printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
-    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+    printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+    --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+    -m sagemaker.remote_function.invoke_function \\n"
+
+    $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+    --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
     -m sagemaker.remote_function.invoke_function "$@"
 else
     printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
-    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+    printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+    --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
+    torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+    --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
 
@@ -263,7 +282,6 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun=False,
-        nproc_per_node=1,
     ):
         """Initialize a _JobSettings instance which configures the remote job.
 
@@ -604,7 +622,6 @@ def __init__(
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
 
         self.use_torchrun = use_torchrun
-        self.nproc_per_node = nproc_per_node
 
     @staticmethod
     def _get_default_image(session):
@@ -732,6 +749,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
         )
 
         logger.info("Creating job: %s", job_name)
+        logger.info("Environment variables: %s", training_job_request["Environment"])
+
         job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)
 
         return _Job(
@@ -776,8 +795,6 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:
@@ -790,8 +807,6 @@ def compile(
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
 
             stored_function.save_pipeline_step_function(serialized_data)
@@ -931,6 +946,7 @@ def compile(
             request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})
 
         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_torchrun_to_request(extended_request, job_settings)
 
         return extended_request
 
@@ -1011,7 +1027,6 @@ def _prepare_and_upload_runtime_scripts(
     s3_kms_key: str,
     sagemaker_session: Session,
     use_torchrun: bool = False,
-    nproc_per_node: int = 1,
 ):
     """Copy runtime scripts to a folder and upload to S3.
 
@@ -1029,8 +1044,6 @@ def _prepare_and_upload_runtime_scripts(
         sagemaker_session (str): SageMaker boto client session.
 
         use_torchrun (bool): Whether to use torchrun or not.
-
-        nproc_per_node (int): Number of processes per node.
     """
 
     from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1054,7 +1067,6 @@ def _prepare_and_upload_runtime_scripts(
 
     if use_torchrun:
         entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
-        entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
 
     with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)
@@ -1094,7 +1106,6 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
         use_torchrun=job_settings.use_torchrun,
-        nproc_per_node=job_settings.nproc_per_node,
     )
 
     input_data_config = [
@@ -1435,6 +1446,35 @@ def _upload_serialized_spark_configuration(
     return config_file_s3_uri
 
 
+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
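Note: _extend_torchrun_to_request only touches multi-node torchrun jobs, forcing every S3 input channel to FullyReplicated so each node receives the full workspace. A minimal usage sketch, assuming the function above is in scope; the request dict and settings object below are hypothetical stand-ins:

from types import SimpleNamespace

# Stand-in for _JobSettings; the helper only reads these two attributes.
settings = SimpleNamespace(use_torchrun=True, instance_count=2)
request = {
    "InputDataConfig": [
        {
            "ChannelName": "code",  # hypothetical channel
            "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/code"}},
        }
    ]
}

extended = _extend_torchrun_to_request(request, settings)
channel = extended["InputDataConfig"][0]["DataSource"]["S3DataSource"]
assert channel["S3DataDistributionType"] == "FullyReplicated"

One caveat worth noting: request_dict.copy() is a shallow copy, so the nested channel dicts are shared with the original request and are mutated in place rather than copied.
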
 def _extend_spark_config_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,