export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+ printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+ cat /opt/ml/input/config/resourceconfig.json

printf "INFO: Bootstrapping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+ source /opt/ml/input/sm_training.env

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
fi

printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
+ printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
+
$conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function\\n"
+ printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
+
python -m sagemaker.remote_function.invoke_function "$@"
fi
"""
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+ printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+ cat /opt/ml/input/config/resourceconfig.json

printf "INFO: Bootstrapping runtime environment.\\n"
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+ source /opt/ml/input/sm_training.env

if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
then
fi

printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
- $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+ printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+ -m sagemaker.remote_function.invoke_function \\n"
+
+ $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
-m sagemaker.remote_function.invoke_function "$@"
else
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
- torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+ printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+ --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
+ torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+ --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
fi
"""
@@ -263,7 +282,6 @@ def __init__(
use_spot_instances=False,
max_wait_time_in_seconds=None,
use_torchrun=False,
- nproc_per_node=1,
):
"""Initialize a _JobSettings instance which configures the remote job.

@@ -604,7 +622,6 @@ def __init__(
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

self.use_torchrun = use_torchrun
- self.nproc_per_node = nproc_per_node

@staticmethod
def _get_default_image(session):
@@ -732,6 +749,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
)

logger.info("Creating job: %s", job_name)
+ logger.info("Environment variables: %s", training_job_request["Environment"])
+
job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

return _Job(
@@ -776,8 +795,6 @@ def compile(
s3_base_uri=s3_base_uri,
hmac_key=hmac_key,
s3_kms_key=job_settings.s3_kms_key,
- use_torchrun=job_settings.use_torchrun,
- nproc_per_node=job_settings.nproc_per_node,
)
stored_function.save(func, *func_args, **func_kwargs)
else:
@@ -790,8 +807,6 @@ def compile(
step_name=step_compilation_context.step_name,
func_step_s3_dir=step_compilation_context.pipeline_build_time,
),
- use_torchrun=job_settings.use_torchrun,
- nproc_per_node=job_settings.nproc_per_node,
)

stored_function.save_pipeline_step_function(serialized_data)
@@ -931,6 +946,7 @@ def compile(
request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+ extended_request = _extend_torchrun_to_request(extended_request, job_settings)

return extended_request

@@ -1011,7 +1027,6 @@ def _prepare_and_upload_runtime_scripts(
s3_kms_key: str,
sagemaker_session: Session,
use_torchrun: bool = False,
- nproc_per_node: int = 1,
):
"""Copy runtime scripts to a folder and upload to S3.

@@ -1029,8 +1044,6 @@ def _prepare_and_upload_runtime_scripts(
sagemaker_session (str): SageMaker boto client session.

use_torchrun (bool): Whether to use torchrun or not.
-
- nproc_per_node (int): Number of processes per node.
"""

from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1054,7 +1067,6 @@ def _prepare_and_upload_runtime_scripts(

if use_torchrun:
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
- entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))

with open(entrypoint_script_path, "w", newline="\n") as file:
file.writelines(entry_point_script)
@@ -1094,7 +1106,6 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=job_settings.sagemaker_session,
use_torchrun=job_settings.use_torchrun,
- nproc_per_node=job_settings.nproc_per_node,
)

input_data_config = [
@@ -1435,6 +1446,35 @@ def _upload_serialized_spark_configuration(
return config_file_s3_uri


+ def _extend_torchrun_to_request(
+     request_dict: Dict,
+     job_settings: _JobSettings,
+ ) -> Dict:
+     """Extend the create training job request with torchrun configuration.
+
+     Args:
+         request_dict (Dict): create training job request dict.
+         job_settings (_JobSettings): the job settings.
+     """
+     use_torchrun = job_settings.use_torchrun
+     instance_count = job_settings.instance_count
+
+     if not use_torchrun:
+         return request_dict
+
+     if instance_count == 1:
+         return request_dict
+
+     extended_request = request_dict.copy()
+
+     for input_channel in extended_request["InputDataConfig"]:
+         s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+         if s3_data_source:
+             s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+     return extended_request
+
+
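For illustration, here is a rough sketch of what the new _extend_torchrun_to_request helper does to a request: with use_torchrun enabled and more than one instance, every S3 input channel is switched to FullyReplicated so each node receives the full runtime payload. The channel name, S3 URI, and the pre-configured job_settings object below are made up for the example.

```python
# Minimal, made-up create_training_job request with a single S3 channel.
request = {
    "InputDataConfig": [
        {
            "ChannelName": "sagemaker_remote_function_bootstrap",  # assumed channel name
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": "s3://example-bucket/runtime_scripts",  # made-up URI
                    "S3DataType": "S3Prefix",
                }
            },
        }
    ]
}

# job_settings is assumed to be a _JobSettings with use_torchrun=True and instance_count=2.
extended = _extend_torchrun_to_request(request, job_settings)

assert (
    extended["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3DataDistributionType"]
    == "FullyReplicated"
)
```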
def _extend_spark_config_to_request(
    request_dict: Dict,
    job_settings: _JobSettings,