-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Added torchrun compatibility for distributed training across multiple GPUs in a single node (single instance) #4766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
fdbf6ba
c253d0b
6815adb
05b2c61
f1b99a4
fb3015f
73a1a62
f6840d1
60a421d
a61d042
e747737
6fce4d6
a508ebf
020f29b
634b8f6
ef92bcf
9681d91
7a31831
5ce1fbd
fb38454
1ad86dd
97c172e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,6 +162,51 @@ | |
fi | ||
""" | ||
|
||
ENTRYPOINT_TORCHRUN_SCRIPT = f""" | ||
#!/bin/bash | ||
|
||
# Entry point for bootstrapping runtime environment and invoking remote function with torchrun | ||
|
||
set -eu | ||
|
||
PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} | ||
export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs | ||
printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" | ||
export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip | ||
printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" | ||
|
||
|
||
printf "INFO: Bootstraping runtime environment.\\n" | ||
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" | ||
|
||
if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] | ||
then | ||
if [ -f "remote_function_conda_env.txt" ] | ||
then | ||
cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt | ||
fi | ||
printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" | ||
cd {JOB_REMOTE_FUNCTION_WORKSPACE} | ||
fi | ||
|
||
if [ -f "remote_function_conda_env.txt" ] | ||
then | ||
conda_env=$(cat remote_function_conda_env.txt) | ||
|
||
if which mamba >/dev/null; then | ||
conda_exe="mamba" | ||
else | ||
conda_exe="conda" | ||
fi | ||
|
||
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" | ||
$conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line too long |
||
else | ||
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" | ||
torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@" | ||
fi | ||
""" | ||
|
||
SPARK_ENTRYPOINT_SCRIPT = f""" | ||
#!/bin/bash | ||
|
||
|
@@ -216,6 +261,8 @@ def __init__( | |
spark_config: SparkConfig = None, | ||
use_spot_instances=False, | ||
max_wait_time_in_seconds=None, | ||
use_torchrun=False, | ||
nproc_per_node=1, | ||
): | ||
"""Initialize a _JobSettings instance which configures the remote job. | ||
|
||
|
@@ -555,6 +602,9 @@ def __init__( | |
tags = format_tags(tags) | ||
self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) | ||
|
||
self.use_torchrun = use_torchrun | ||
self.nproc_per_node = nproc_per_node | ||
|
||
@staticmethod | ||
def _get_default_image(session): | ||
"""Return Studio notebook image, if in Studio env. Else, base python. | ||
|
@@ -951,7 +1001,12 @@ def _get_job_name(job_settings, func): | |
|
||
|
||
def _prepare_and_upload_runtime_scripts( | ||
spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session | ||
spark_config: SparkConfig, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Too much whitespace, looks like two tab instead of 1 maybe |
||
s3_base_uri: str, | ||
s3_kms_key: str, | ||
sagemaker_session: Session, | ||
use_torchrun: bool = False, | ||
nproc_per_node: int = 1, | ||
): | ||
"""Copy runtime scripts to a folder and upload to S3. | ||
|
||
|
@@ -988,6 +1043,10 @@ def _prepare_and_upload_runtime_scripts( | |
) | ||
shutil.copy2(spark_script_path, bootstrap_scripts) | ||
|
||
if use_torchrun: | ||
entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT | ||
entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node)) | ||
|
||
with open(entrypoint_script_path, "w", newline="\n") as file: | ||
file.writelines(entry_point_script) | ||
|
||
|
@@ -1025,6 +1084,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): | |
s3_base_uri=s3_base_uri, | ||
s3_kms_key=job_settings.s3_kms_key, | ||
sagemaker_session=job_settings.sagemaker_session, | ||
use_torchrun=job_settings.use_torchrun, | ||
nproc_per_node=job_settings.nproc_per_node, | ||
) | ||
|
||
input_data_config = [ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -818,3 +818,25 @@ def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container): | |
f"--rm {auto_capture_test_container}" | ||
) | ||
subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8") | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and a new line here |
||
def test_decorator_torchrun( | ||
sagemaker_session, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, needs 1 tab instead of 2. |
||
dummy_container_without_error, | ||
gpu_instance_type, | ||
use_torchrun=True, | ||
nproc_per_node=2, | ||
): | ||
@remote( | ||
role=ROLE, | ||
image_uri=dummy_container_without_error, | ||
instance_type=gpu_instance_type, | ||
sagemaker_session=sagemaker_session, | ||
keep_alive_period_in_seconds=60, | ||
use_torchrun=use_torchrun, | ||
nproc_per_node=nproc_per_node, | ||
) | ||
def divide(x, y): | ||
return x / y | ||
|
||
assert divide(10, 2) == 5 | ||
assert divide(20, 2) == 10 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs a new line here