
Commit 841af92

implemented multi-node distribution with @Remote function
1 parent 2102bb7 commit 841af92
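
For orientation, a minimal sketch (not part of this commit) of how the multi-node path added here is driven from user code, assuming a configured SageMaker session and role and a job image with torch available:

    from sagemaker.remote_function import remote

    # With this commit, instance_count > 1 is accepted when use_torchrun=True;
    # previously multi-instance jobs required spark_config. Each node runs the
    # function under torchrun, so torch.distributed can be initialized inside it.
    @remote(instance_count=2, use_torchrun=True)
    def train():
        import torch.distributed as dist

        dist.init_process_group(backend="gloo")
        return dist.get_rank()

    train()  # launches a two-node SageMaker training job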

File tree

8 files changed (+371, −102 lines)


src/sagemaker/remote_function/client.py

Lines changed: 12 additions & 14 deletions
@@ -91,7 +91,6 @@ def remote(
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
     use_torchrun=False,
-    nproc_per_node=1,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -283,9 +282,6 @@ def remote(

         use_torchrun (bool): Specifies whether to use torchrun for distributed training.
             Defaults to ``False``.
-
-        nproc_per_node (int): Specifies the number of processes per node for distributed training.
-            Defaults to ``1``.
     """

     def _remote(func):
@@ -319,15 +315,18 @@ def _remote(func):
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
-            nproc_per_node=nproc_per_node,
         )

         @functools.wraps(func)
         def wrapper(*args, **kwargs):

-            if instance_count > 1 and not spark_config:
+            if instance_count > 1 and not (
+                (spark_config is not None and not use_torchrun)
+                or (spark_config is None and use_torchrun)
+            ):
                 raise ValueError(
-                    "Remote function do not support training on multi instances. "
+                    "Remote function do not support training on multi instances "
+                    + "without spark_config or use_torchrun. "
                     + "Please provide instance_count = 1"
                 )

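The reworked guard above accepts instance_count > 1 exactly when one of spark_config or use_torchrun is set, but not both. A stand-alone restatement of the condition (hypothetical helper, for illustration only):

    def multi_node_allowed(spark_config, use_torchrun):
        # Mirrors the condition introduced in this commit: exactly one of the
        # two distribution mechanisms must be selected for multi-instance jobs.
        return (spark_config is not None and not use_torchrun) or (
            spark_config is None and use_torchrun
        )

    assert multi_node_allowed(object(), False)      # Spark path
    assert multi_node_allowed(None, True)           # torchrun path
    assert not multi_node_allowed(None, False)      # no mechanism -> ValueError
    assert not multi_node_allowed(object(), True)   # both -> ValueError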
@@ -532,7 +531,6 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun=False,
-        nproc_per_node=1,
     ):
         """Constructor for RemoteExecutor

@@ -724,18 +722,19 @@ def __init__(

             use_torchrun (bool): Specifies whether to use torchrun for distributed training.
                 Defaults to ``False``.
-
-            nproc_per_node (int): Specifies the number of processes per node.
-                Defaults to ``1``.
         """
         self.max_parallel_jobs = max_parallel_jobs

         if self.max_parallel_jobs <= 0:
             raise ValueError("max_parallel_jobs must be greater than 0.")

-        if instance_count > 1 and not spark_config:
+        if instance_count > 1 and not (
+            (spark_config is not None and not use_torchrun)
+            or (spark_config is None and use_torchrun)
+        ):
             raise ValueError(
-                "Remote function do not support training on multi instances. "
+                "Remote function do not support training on multi instances "
+                + "without spark_config or use_torchrun. "
                 + "Please provide instance_count = 1"
             )

@@ -768,7 +767,6 @@ def __init__(
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
             use_torchrun=use_torchrun,
-            nproc_per_node=nproc_per_node,
         )

         self._state_condition = threading.Condition()

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 0 additions & 6 deletions
@@ -55,8 +55,6 @@ def __init__(
         hmac_key: str,
         s3_kms_key: str = None,
         context: Context = Context(),
-        use_torchrun: bool = False,
-        nproc_per_node: int = 1,
     ):
         """Construct a StoredFunction object.

@@ -67,16 +65,12 @@
             s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
             hmac_key: Key used to encrypt serialized and deserialized function and arguments.
             context: Build or run context of a pipeline step.
-            use_torchrun: Whether to use torchrun for distributed training.
-            nproc_per_node: Number of processes per node for distributed training.
         """
         self.sagemaker_session = sagemaker_session
         self.s3_base_uri = s3_base_uri
         self.s3_kms_key = s3_kms_key
         self.hmac_key = hmac_key
         self.context = context
-        self.use_torchrun = use_torchrun
-        self.nproc_per_node = nproc_per_node

         self.func_upload_path = s3_path_join(
             s3_base_uri, context.step_name, context.func_step_s3_dir

src/sagemaker/remote_function/job.py

Lines changed: 53 additions & 13 deletions
@@ -130,9 +130,12 @@
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstraping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then

@@ -155,9 +158,13 @@
 fi

 printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
+printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n"
+
 $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@"
 else
 printf "INFO: No conda env provided. Invoking remote function\\n"
+printf "INFO: python -m sagemaker.remote_function.invoke_function \\n"
+
 python -m sagemaker.remote_function.invoke_function "$@"
 fi
 """
@@ -175,9 +182,12 @@
 export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
 printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"

+printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json

 printf "INFO: Bootstraping runtime environment.\\n"
 python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env

 if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
 then

@@ -200,11 +210,20 @@
 fi

 printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
-$conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+-m sagemaker.remote_function.invoke_function \\n"
+
+$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
 -m sagemaker.remote_function.invoke_function "$@"
 else
 printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
-torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
+torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
 fi
 """

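Both entrypoint scripts now dump resourceconfig.json and source /opt/ml/input/sm_training.env, which is expected to define the SM_* variables consumed by the torchrun command; the env file itself is produced during bootstrap in a part of this commit not shown in this excerpt. A hedged sketch of how those variables could be derived from the resource config SageMaker mounts on every node (the port value and the env-file production are assumptions, not taken from this diff):

    import json

    # Read the cluster layout that SageMaker provides on each training node.
    with open("/opt/ml/input/config/resourceconfig.json") as f:
        cfg = json.load(f)

    hosts = sorted(cfg["hosts"])  # e.g. ["algo-1", "algo-2"]
    env = {
        "SM_HOST_COUNT": str(len(hosts)),                               # --nnodes
        "SM_CURRENT_HOST_RANK": str(hosts.index(cfg["current_host"])),  # --node_rank
        "SM_MASTER_ADDR": hosts[0],  # first host acts as rendezvous master
        "SM_MASTER_PORT": "7777",    # assumed fixed port
        # SM_NPROC_PER_NODE would typically come from the detected GPU/CPU count.
    }

    # Emit KEY=VALUE lines so the entrypoint can `source` the file.
    with open("/opt/ml/input/sm_training.env", "w") as f:
        f.writelines(f"{k}={v}\n" for k, v in env.items())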
@@ -263,7 +282,6 @@ def __init__(
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
         use_torchrun=False,
-        nproc_per_node=1,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -604,7 +622,6 @@ def __init__(
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

         self.use_torchrun = use_torchrun
-        self.nproc_per_node = nproc_per_node

     @staticmethod
     def _get_default_image(session):

@@ -732,6 +749,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None):
         )

         logger.info("Creating job: %s", job_name)
+        logger.info("Environment variables: %s", training_job_request["Environment"])
+
         job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)

         return _Job(
@@ -776,8 +795,6 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:

@@ -790,8 +807,6 @@
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
-                use_torchrun=job_settings.use_torchrun,
-                nproc_per_node=job_settings.nproc_per_node,
             )

             stored_function.save_pipeline_step_function(serialized_data)

@@ -931,6 +946,7 @@ def compile(
             request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key})

         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
+        extended_request = _extend_torchrun_to_request(extended_request, job_settings)

         return extended_request

@@ -1011,7 +1027,6 @@ def _prepare_and_upload_runtime_scripts(
     s3_kms_key: str,
     sagemaker_session: Session,
     use_torchrun: bool = False,
-    nproc_per_node: int = 1,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -1029,8 +1044,6 @@ def _prepare_and_upload_runtime_scripts(
         sagemaker_session (str): SageMaker boto client session.

         use_torchrun (bool): Whether to use torchrun or not.
-
-        nproc_per_node (int): Number of processes per node.
     """

     from sagemaker.workflow.utilities import load_step_compilation_context

@@ -1054,7 +1067,6 @@

     if use_torchrun:
         entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
-        entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))

     with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)

@@ -1094,7 +1106,6 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
         use_torchrun=job_settings.use_torchrun,
-        nproc_per_node=job_settings.nproc_per_node,
     )

     input_data_config = [
@@ -1435,6 +1446,35 @@ def _upload_serialized_spark_configuration(
     return config_file_s3_uri


+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
 def _extend_spark_config_to_request(
     request_dict: Dict,
     job_settings: _JobSettings,
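
A short sketch of what the new helper does to a pared-down request (channel field names follow the CreateTrainingJob API; the stand-in settings object is hypothetical):

    from sagemaker.remote_function.job import _extend_torchrun_to_request

    class _FakeSettings:  # stand-in for _JobSettings, illustration only
        use_torchrun = True
        instance_count = 2

    request = {
        "InputDataConfig": [
            {
                "ChannelName": "code",
                "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/code"}},
            }
        ]
    }

    extended = _extend_torchrun_to_request(request, _FakeSettings())
    # Every S3 channel is forced to FullyReplicated so each node downloads the
    # same inputs, which the per-node torchrun bootstrap relies on.
    assert (
        extended["InputDataConfig"][0]["DataSource"]["S3DataSource"][
            "S3DataDistributionType"
        ]
        == "FullyReplicated"
    )

Note that request_dict.copy() is shallow, so the nested channel dicts of the original request are mutated in place as well.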
