diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py
index 53a116e4ef..73a308ddf5 100644
--- a/src/sagemaker/remote_function/client.py
+++ b/src/sagemaker/remote_function/client.py
@@ -90,6 +90,8 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
+    use_torchrun=False,
+    nproc_per_node=1,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -278,6 +280,12 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
+
+        use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+            Defaults to ``False``.
+
+        nproc_per_node (int): Specifies the number of processes per node for distributed training.
+            Defaults to ``1``.
     """

     def _remote(func):
@@ -310,6 +318,8 @@ def _remote(func):
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
+            use_torchrun=use_torchrun,
+            nproc_per_node=nproc_per_node,
         )

         @functools.wraps(func)
@@ -521,6 +531,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Constructor for RemoteExecutor

@@ -709,6 +721,12 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (int): Specifies the number of processes per node.
+                Defaults to ``1``.
         """

         self.max_parallel_jobs = max_parallel_jobs
@@ -749,6 +767,8 @@ def __init__(
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
+            use_torchrun=use_torchrun,
+            nproc_per_node=nproc_per_node,
         )

         self._state_condition = threading.Condition()
diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py
index 862c67d9ee..ade4a9e652 100644
--- a/src/sagemaker/remote_function/core/stored_function.py
+++ b/src/sagemaker/remote_function/core/stored_function.py
@@ -55,6 +55,8 @@ def __init__(
         hmac_key: str,
         s3_kms_key: str = None,
         context: Context = Context(),
+        use_torchrun: bool = False,
+        nproc_per_node: int = 1,
     ):
         """Construct a StoredFunction object.

@@ -65,12 +67,16 @@ def __init__(
             s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
             hmac_key: Key used to encrypt serialized and deserialized function and arguments.
             context: Build or run context of a pipeline step.
+            use_torchrun: Whether to use torchrun for distributed training.
+            nproc_per_node: Number of processes per node for distributed training.
         """
         self.sagemaker_session = sagemaker_session
         self.s3_base_uri = s3_base_uri
         self.s3_kms_key = s3_kms_key
         self.hmac_key = hmac_key
         self.context = context
+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node

         self.func_upload_path = s3_path_join(
             s3_base_uri, context.step_name, context.func_step_s3_dir
diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py
index 5814ee45ff..8ab4d420e5 100644
--- a/src/sagemaker/remote_function/job.py
+++ b/src/sagemaker/remote_function/job.py
@@ -162,6 +162,52 @@
 fi
 """

+ENTRYPOINT_TORCHRUN_SCRIPT = f"""
+#!/bin/bash
+
+# Entry point for bootstrapping runtime environment and invoking remote function with torchrun
+
+set -eu
+
+PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+
+printf "INFO: Bootstrapping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
+    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+    -m sagemaker.remote_function.invoke_function "$@"
+else
+    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
+    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+fi
+"""
+
 SPARK_ENTRYPOINT_SCRIPT = f"""
 #!/bin/bash

@@ -216,6 +262,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -555,6 +603,9 @@ def __init__(
         tags = format_tags(tags)
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node
+
     @staticmethod
     def _get_default_image(session):
         """Return Studio notebook image, if in Studio env. Else, base python.
@@ -725,6 +776,8 @@ def compile(
             s3_base_uri=s3_base_uri,
             hmac_key=hmac_key,
             s3_kms_key=job_settings.s3_kms_key,
+            use_torchrun=job_settings.use_torchrun,
+            nproc_per_node=job_settings.nproc_per_node,
         )
         stored_function.save(func, *func_args, **func_kwargs)
     else:
@@ -737,6 +790,8 @@ def compile(
                 step_name=step_compilation_context.step_name,
                 func_step_s3_dir=step_compilation_context.pipeline_build_time,
             ),
+            use_torchrun=job_settings.use_torchrun,
+            nproc_per_node=job_settings.nproc_per_node,
         )
         stored_function.save_pipeline_step_function(serialized_data)

@@ -951,7 +1006,12 @@ def _get_job_name(job_settings, func):


 def _prepare_and_upload_runtime_scripts(
-    spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
+    spark_config: SparkConfig,
+    s3_base_uri: str,
+    s3_kms_key: str,
+    sagemaker_session: Session,
+    use_torchrun: bool = False,
+    nproc_per_node: int = 1,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -967,6 +1027,10 @@ def _prepare_and_upload_runtime_scripts(
         s3_kms_key (str): kms key used to encrypt the files uploaded to S3.

         sagemaker_session (str): SageMaker boto client session.
+
+        use_torchrun (bool): Whether to use torchrun or not.
+
+        nproc_per_node (int): Number of processes per node.
     """
     from sagemaker.workflow.utilities import load_step_compilation_context

@@ -988,6 +1052,10 @@ def _prepare_and_upload_runtime_scripts(
         )
         shutil.copy2(spark_script_path, bootstrap_scripts)

+    if use_torchrun:
+        entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
+        entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
     with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)

@@ -1025,6 +1093,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
+        use_torchrun=job_settings.use_torchrun,
+        nproc_per_node=job_settings.nproc_per_node,
     )

     input_data_config = [
diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py
index 63ced1dd9c..2717bb9afe 100644
--- a/tests/integ/sagemaker/remote_function/test_decorator.py
+++ b/tests/integ/sagemaker/remote_function/test_decorator.py
@@ -818,3 +818,26 @@ def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
         f"--rm {auto_capture_test_container}"
     )
     subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")
+
+
+def test_decorator_torchrun(
+    sagemaker_session,
+    dummy_container_without_error,
+    gpu_instance_type,
+    use_torchrun=False,
+    nproc_per_node=1,
+):
+    @remote(
+        role=ROLE,
+        image_uri=dummy_container_without_error,
+        instance_type=gpu_instance_type,
+        sagemaker_session=sagemaker_session,
+        keep_alive_period_in_seconds=60,
+        use_torchrun=use_torchrun,
+        nproc_per_node=nproc_per_node,
+    )
+    def divide(x, y):
+        return x / y
+
+    assert divide(10, 2) == 5
+    assert divide(20, 2) == 10
diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py
index 9020a9f05f..57f4a54f78 100644
--- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py
+++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py
@@ -907,6 +907,8 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
         "use_spot_instances",
         "max_wait_time_in_seconds",
         "custom_file_filter",
+        "use_torchrun",
+        "nproc_per_node",
     }

     job_settings = _JobSettings(
diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py
index 1d752f89ed..536bfdfca7 100644
--- a/tests/unit/sagemaker/remote_function/test_client.py
+++ b/tests/unit/sagemaker/remote_function/test_client.py
@@ -1504,6 +1504,8 @@ def test_consistency_between_remote_and_step_decorator():
         "s3_kms_key",
         "s3_root_uri",
         "sagemaker_session",
+        "use_torchrun",
+        "nproc_per_node",
     ]

     step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"]
diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py
index 98961ad80d..888c634bfe 100644
--- a/tests/unit/sagemaker/remote_function/test_job.py
+++ b/tests/unit/sagemaker/remote_function/test_job.py
@@ -376,6 +376,8 @@ def test_start(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         hmac_key=HMAC_KEY,
         s3_kms_key=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4})
@@ -389,6 +391,8 @@ def test_start(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=None,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_dependency_upload.assert_called_once_with(
@@ -506,6 +510,8 @@ def test_start_with_checkpoint_location(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         hmac_key=HMAC_KEY,
         s3_kms_key=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_stored_function().save.assert_called_once_with(
@@ -659,6 +665,8 @@ def test_start_with_complete_job_settings(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         hmac_key=HMAC_KEY,
         s3_kms_key=KMS_KEY_ARN,
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     local_dependencies_path = mock_runtime_manager().snapshot()
@@ -670,6 +678,8 @@ def test_start_with_complete_job_settings(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_user_workspace_upload.assert_called_once_with(
@@ -828,6 +838,8 @@ def test_get_train_args_under_pipeline_context(
             step_name=MOCKED_PIPELINE_CONFIG.step_name,
             func_step_s3_dir=MOCKED_PIPELINE_CONFIG.pipeline_build_time,
         ),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_stored_function.save_pipeline_step_function.assert_called_once_with(mocked_serialized_data)
@@ -840,6 +852,8 @@ def test_get_train_args_under_pipeline_context(
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     mock_user_workspace_upload.assert_called_once_with(
@@ -1014,6 +1028,8 @@ def test_start_with_spark(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=None,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -1168,6 +1184,8 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload):
         s3_base_uri=S3_URI,
         s3_kms_key=KMS_KEY_ARN,
         sagemaker_session=session(),
+        use_torchrun=False,
+        nproc_per_node=1,
     )

     assert s3_path == mock_s3_upload.return_value