Skip to content

Commit adcc38e

Browse files
committed
refactor nproc_per_node for backwards compatibility
1 parent 423c585 commit adcc38e

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def remote(
9191
use_spot_instances=False,
9292
max_wait_time_in_seconds=None,
9393
use_torchrun=False,
94-
nproc_per_node=1,
94+
nproc_per_node: Optional[int] = None,
9595
):
9696
"""Decorator for running the annotated function as a SageMaker training job.
9797
@@ -284,8 +284,8 @@ def remote(
284284
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
285285
Defaults to ``False``.
286286
287-
nproc_per_node (int): Specifies the number of processes per node for
288-
distributed training. Defaults to ``1``.
287+
nproc_per_node (Optional int): Specifies the number of processes per node for
288+
distributed training. Defaults to ``None``.
289289
This is defined automatically configured on the instance type.
290290
"""
291291

@@ -320,6 +320,7 @@ def _remote(func):
320320
use_spot_instances=use_spot_instances,
321321
max_wait_time_in_seconds=max_wait_time_in_seconds,
322322
use_torchrun=use_torchrun,
323+
nproc_per_node=nproc_per_node,
323324
)
324325

325326
@functools.wraps(func)
@@ -536,7 +537,7 @@ def __init__(
536537
use_spot_instances=False,
537538
max_wait_time_in_seconds=None,
538539
use_torchrun=False,
539-
nproc_per_node=1,
540+
nproc_per_node: Optional[int] = None,
540541
):
541542
"""Constructor for RemoteExecutor
542543
@@ -729,8 +730,8 @@ def __init__(
729730
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
730731
Defaults to ``False``.
731732
732-
nproc_per_node (int): Specifies the number of processes per node for
733-
distributed training. Defaults to ``1``.
733+
nproc_per_node (Optional int): Specifies the number of processes per node for
734+
distributed training. Defaults to ``None``.
734735
This is defined automatically configured on the instance type.
735736
"""
736737
self.max_parallel_jobs = max_parallel_jobs
@@ -777,6 +778,7 @@ def __init__(
777778
use_spot_instances=use_spot_instances,
778779
max_wait_time_in_seconds=max_wait_time_in_seconds,
779780
use_torchrun=use_torchrun,
781+
nproc_per_node=nproc_per_node,
780782
)
781783

782784
self._state_condition = threading.Condition()

src/sagemaker/remote_function/job.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def __init__(
282282
use_spot_instances=False,
283283
max_wait_time_in_seconds=None,
284284
use_torchrun: bool = False,
285+
nproc_per_node: Optional[int] = None,
285286
):
286287
"""Initialize a _JobSettings instance which configures the remote job.
287288
@@ -463,6 +464,13 @@ def __init__(
463464
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
464465
After this amount of time Amazon SageMaker will stop waiting for managed spot
465466
training job to complete. Defaults to ``None``.
467+
468+
use_torchrun (bool): Specifies whether to use torchrun for distributed training.
469+
Defaults to ``False``.
470+
471+
nproc_per_node (Optional int): Specifies the number of processes per node for
472+
distributed training. Defaults to ``None``.
473+
This is defined automatically configured on the instance type.
466474
"""
467475
self.sagemaker_session = sagemaker_session or Session()
468476
self.environment_variables = resolve_value_from_config(
@@ -622,6 +630,7 @@ def __init__(
622630
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)
623631

624632
self.use_torchrun = use_torchrun
633+
self.nproc_per_node = nproc_per_node
625634

626635
@staticmethod
627636
def _get_default_image(session):
@@ -749,7 +758,6 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
749758
)
750759

751760
logger.info("Creating job: %s", job_name)
752-
logger.info("Environment variables: %s", training_job_request["Environment"])
753761

754762
job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request)
755763

@@ -1027,6 +1035,7 @@ def _prepare_and_upload_runtime_scripts(
10271035
s3_kms_key: str,
10281036
sagemaker_session: Session,
10291037
use_torchrun: bool = False,
1038+
nproc_per_node: Optional[int] = None,
10301039
):
10311040
"""Copy runtime scripts to a folder and upload to S3.
10321041
@@ -1044,6 +1053,8 @@ def _prepare_and_upload_runtime_scripts(
10441053
sagemaker_session (str): SageMaker boto client session.
10451054
10461055
use_torchrun (bool): Whether to use torchrun or not.
1056+
1057+
nproc_per_node (Optional[int]): Number of processes per node
10471058
"""
10481059

10491060
from sagemaker.workflow.utilities import load_step_compilation_context
@@ -1068,6 +1079,12 @@ def _prepare_and_upload_runtime_scripts(
10681079
if use_torchrun:
10691080
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
10701081

1082+
if nproc_per_node is not None and nproc_per_node > 0:
1083+
entry_point_script = entry_point_script.replace(
1084+
"$SM_NPROC_PER_NODE",
1085+
str(nproc_per_node)
1086+
)
1087+
10711088
with open(entrypoint_script_path, "w", newline="\n") as file:
10721089
file.writelines(entry_point_script)
10731090

@@ -1106,6 +1123,7 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
11061123
s3_kms_key=job_settings.s3_kms_key,
11071124
sagemaker_session=job_settings.sagemaker_session,
11081125
use_torchrun=job_settings.use_torchrun,
1126+
nproc_per_node=job_settings.nproc_per_node,
11091127
)
11101128

11111129
input_data_config = [

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def test_start(
390390
s3_kms_key=None,
391391
sagemaker_session=session(),
392392
use_torchrun=False,
393+
nproc_per_node=None,
393394
)
394395

395396
mock_dependency_upload.assert_called_once_with(
@@ -672,6 +673,7 @@ def test_start_with_complete_job_settings(
672673
s3_kms_key=job_settings.s3_kms_key,
673674
sagemaker_session=session(),
674675
use_torchrun=False,
676+
nproc_per_node=None,
675677
)
676678

677679
mock_user_workspace_upload.assert_called_once_with(
@@ -843,6 +845,7 @@ def test_get_train_args_under_pipeline_context(
843845
s3_kms_key=job_settings.s3_kms_key,
844846
sagemaker_session=session(),
845847
use_torchrun=False,
848+
nproc_per_node=None,
846849
)
847850

848851
mock_user_workspace_upload.assert_called_once_with(
@@ -1018,6 +1021,7 @@ def test_start_with_spark(
10181021
s3_kms_key=None,
10191022
sagemaker_session=session(),
10201023
use_torchrun=False,
1024+
nproc_per_node=None,
10211025
)
10221026

10231027
session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -1633,6 +1637,7 @@ def test_start_with_torchrun_single_node(
16331637
instance_type="ml.g5.12xlarge",
16341638
encrypt_inter_container_traffic=True,
16351639
use_torchrun=True,
1640+
nproc_per_node=None,
16361641
)
16371642

16381643
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
@@ -1658,6 +1663,7 @@ def test_start_with_torchrun_single_node(
16581663
s3_kms_key=None,
16591664
sagemaker_session=session(),
16601665
use_torchrun=True,
1666+
nproc_per_node=None,
16611667
)
16621668

16631669
mock_dependency_upload.assert_called_once_with(
@@ -1759,6 +1765,7 @@ def test_start_with_torchrun_multi_node(
17591765
instance_type="ml.g5.2xlarge",
17601766
encrypt_inter_container_traffic=True,
17611767
use_torchrun=True,
1768+
nproc_per_node=None,
17621769
)
17631770

17641771
job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
@@ -1784,6 +1791,7 @@ def test_start_with_torchrun_multi_node(
17841791
s3_kms_key=None,
17851792
sagemaker_session=session(),
17861793
use_torchrun=True,
1794+
nproc_per_node=None,
17871795
)
17881796

17891797
mock_dependency_upload.assert_called_once_with(

0 commit comments

Comments
 (0)