
Commit fc0d037

Merge branch 'master' into dependabot/pip/requirements/extras/scikit-learn-1.5.0
2 parents 4871c4d + a870e19 commit fc0d037

File tree

17 files changed: +501 −14 lines

src/sagemaker/clarify.py

Lines changed: 1 addition & 1 deletion

@@ -870,7 -870,7 @@ class BiasConfig:

     def __init__(
         self,
-        label_values_or_threshold: Union[int, float, str],
+        label_values_or_threshold: List[Union[int, float, str]],
         facet_name: Union[str, int, List[str], List[int]],
         facet_values_or_threshold: Optional[Union[int, float, str]] = None,
         group_name: Optional[str] = None,
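
The corrected annotation matches how BiasConfig is called in practice: label_values_or_threshold takes a list of values that denote the positive outcome. A minimal sketch, assuming an illustrative dataset with "Sex" and "Age" columns:

    from sagemaker.clarify import BiasConfig

    # Positive outcome is label value 1; the facet of interest is Sex == 0,
    # with results optionally grouped by the Age column.
    bias_config = BiasConfig(
        label_values_or_threshold=[1],
        facet_name="Sex",
        facet_values_or_threshold=[0],
        group_name="Age",
    )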

src/sagemaker/image_uri_config/model-monitor.json

Lines changed: 0 additions & 2 deletions

@@ -14,12 +14,10 @@
     "ap-southeast-1": "245545462676",
     "ap-southeast-2": "563025443158",
     "ap-southeast-3": "669540362728",
-    "ap-southeast-5": "654654579213",
     "ca-central-1": "536280801234",
     "cn-north-1": "453000072557",
     "cn-northwest-1": "453252182341",
     "eu-central-1": "048819808253",
-    "eu-central-2": "590183933784",
     "eu-north-1": "895015795356",
     "eu-south-1": "933208885752",
     "eu-south-2": "437450045455",

src/sagemaker/remote_function/client.py

Lines changed: 20 additions & 0 deletions

@@ -90,6 +90,8 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
+    use_torchrun=False,
+    nproc_per_node=1,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -278,6 +280,12 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
+
+        use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+            Defaults to ``False``.
+
+        nproc_per_node (int): Specifies the number of processes per node for distributed training.
+            Defaults to ``1``.
     """

     def _remote(func):

@@ -310,6 +318,8 @@ def _remote(func):
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
+            use_torchrun=use_torchrun,
+            nproc_per_node=nproc_per_node,
         )

         @functools.wraps(func)

@@ -521,6 +531,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Constructor for RemoteExecutor

@@ -709,6 +721,12 @@ def __init__(
            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                After this amount of time Amazon SageMaker will stop waiting for managed spot training
                job to complete. Defaults to ``None``.
+
+           use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+               Defaults to ``False``.
+
+           nproc_per_node (int): Specifies the number of processes per node.
+               Defaults to ``1``.
        """
        self.max_parallel_jobs = max_parallel_jobs

@@ -749,6 +767,8 @@ def __init__(
            spark_config=spark_config,
            use_spot_instances=use_spot_instances,
            max_wait_time_in_seconds=max_wait_time_in_seconds,
+           use_torchrun=use_torchrun,
+           nproc_per_node=nproc_per_node,
        )

        self._state_condition = threading.Condition()
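
Taken together, the client-side change just threads two new knobs through to the job settings. A minimal sketch of how a caller might opt in (instance type, process count, and function body are illustrative):

    from sagemaker.remote_function import remote

    @remote(
        instance_type="ml.g5.12xlarge",  # illustrative multi-GPU instance
        use_torchrun=True,               # launch the function under torchrun
        nproc_per_node=4,                # e.g. one worker process per GPU
    )
    def train():
        # distributed training code goes here; torchrun sets up the
        # torch.distributed environment for each worker process
        ...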

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 6 additions & 0 deletions

@@ -55,6 +55,8 @@ def __init__(
         hmac_key: str,
         s3_kms_key: str = None,
         context: Context = Context(),
+        use_torchrun: bool = False,
+        nproc_per_node: int = 1,
     ):
         """Construct a StoredFunction object.

@@ -65,12 +67,16 @@ def __init__(
             s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
             hmac_key: Key used to encrypt serialized and deserialized function and arguments.
             context: Build or run context of a pipeline step.
+            use_torchrun: Whether to use torchrun for distributed training.
+            nproc_per_node: Number of processes per node for distributed training.
         """
         self.sagemaker_session = sagemaker_session
         self.s3_base_uri = s3_base_uri
         self.s3_kms_key = s3_kms_key
         self.hmac_key = hmac_key
         self.context = context
+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node

         self.func_upload_path = s3_path_join(
             s3_base_uri, context.step_name, context.func_step_s3_dir

src/sagemaker/remote_function/job.py

Lines changed: 71 additions & 1 deletion

@@ -162,6 +162,52 @@
 fi
 """

+ENTRYPOINT_TORCHRUN_SCRIPT = f"""
+#!/bin/bash
+
+# Entry point for bootstrapping runtime environment and invoking remote function with torchrun
+
+set -eu
+
+PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}}
+export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs
+printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n"
+export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip
+printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n"
+
+
+printf "INFO: Bootstraping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
+    $conda_exe run -n $conda_env torchrun --nproc_per_node $NPROC_PER_NODE \
+        -m sagemaker.remote_function.invoke_function "$@"
+else
+    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
+    torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"
+fi
+"""
+
 SPARK_ENTRYPOINT_SCRIPT = f"""
 #!/bin/bash

@@ -216,6 +262,8 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        use_torchrun=False,
+        nproc_per_node=1,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -555,6 +603,9 @@ def __init__(
         tags = format_tags(tags)
         self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS)

+        self.use_torchrun = use_torchrun
+        self.nproc_per_node = nproc_per_node
+
     @staticmethod
     def _get_default_image(session):
         """Return Studio notebook image, if in Studio env. Else, base python.

@@ -725,6 +776,8 @@ def compile(
                 s3_base_uri=s3_base_uri,
                 hmac_key=hmac_key,
                 s3_kms_key=job_settings.s3_kms_key,
+                use_torchrun=job_settings.use_torchrun,
+                nproc_per_node=job_settings.nproc_per_node,
             )
             stored_function.save(func, *func_args, **func_kwargs)
         else:

@@ -737,6 +790,8 @@ def compile(
                     step_name=step_compilation_context.step_name,
                     func_step_s3_dir=step_compilation_context.pipeline_build_time,
                 ),
+                use_torchrun=job_settings.use_torchrun,
+                nproc_per_node=job_settings.nproc_per_node,
             )

             stored_function.save_pipeline_step_function(serialized_data)

@@ -951,7 +1006,12 @@ def _get_job_name(job_settings, func):


 def _prepare_and_upload_runtime_scripts(
-    spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session
+    spark_config: SparkConfig,
+    s3_base_uri: str,
+    s3_kms_key: str,
+    sagemaker_session: Session,
+    use_torchrun: bool = False,
+    nproc_per_node: int = 1,
 ):
     """Copy runtime scripts to a folder and upload to S3.

@@ -967,6 +1027,10 @@ def _prepare_and_upload_runtime_scripts(
         s3_kms_key (str): kms key used to encrypt the files uploaded to S3.

         sagemaker_session (str): SageMaker boto client session.
+
+        use_torchrun (bool): Whether to use torchrun or not.
+
+        nproc_per_node (int): Number of processes per node.
     """

     from sagemaker.workflow.utilities import load_step_compilation_context

@@ -988,6 +1052,10 @@ def _prepare_and_upload_runtime_scripts(
         )
         shutil.copy2(spark_script_path, bootstrap_scripts)

+    if use_torchrun:
+        entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
+        entry_point_script = entry_point_script.replace("$NPROC_PER_NODE", str(nproc_per_node))
+
     with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)

@@ -1025,6 +1093,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=job_settings.sagemaker_session,
+        use_torchrun=job_settings.use_torchrun,
+        nproc_per_node=job_settings.nproc_per_node,
     )

     input_data_config = [
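
One detail of the entrypoint wiring is easy to miss: because ENTRYPOINT_TORCHRUN_SCRIPT is an f-string, only the {NAME} placeholders are filled at import time, while the shell-style $NPROC_PER_NODE token survives formatting and is substituted later with a plain string replace. A minimal sketch of the same pattern (template abbreviated to the torchrun line):

    # f-string formatting leaves "$NPROC_PER_NODE" intact; .replace() fills it at upload time
    template = 'torchrun --nproc_per_node $NPROC_PER_NODE -m sagemaker.remote_function.invoke_function "$@"'

    nproc_per_node = 4  # illustrative value
    print(template.replace("$NPROC_PER_NODE", str(nproc_per_node)))
    # torchrun --nproc_per_node 4 -m sagemaker.remote_function.invoke_function "$@"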

src/sagemaker/serve/builder/model_builder.py

Lines changed: 25 additions & 5 deletions

@@ -36,6 +36,7 @@
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
+from sagemaker.serve.mode.in_process_mode import InProcessMode
 from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
 from sagemaker.serve.builder.serve_settings import _ServeSettings
 from sagemaker.serve.builder.djl_builder import DJL

@@ -410,7 +411,7 @@ def _prepare_for_mode(
             )
             self.env_vars.update(env_vars_sagemaker)
             return self.s3_upload_path, env_vars_sagemaker
-        if self.mode == Mode.LOCAL_CONTAINER:
+        elif self.mode == Mode.LOCAL_CONTAINER:
             # init the LocalContainerMode object
             self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
                 inference_spec=self.inference_spec,

@@ -422,9 +423,22 @@ def _prepare_for_mode(
             )
             self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
             return None
+        elif self.mode == Mode.IN_PROCESS:
+            # init the InProcessMode object
+            self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
+                inference_spec=self.inference_spec,
+                schema_builder=self.schema_builder,
+                session=self.sagemaker_session,
+                model_path=self.model_path,
+                env_vars=self.env_vars,
+                model_server=self.model_server,
+            )
+            self.modes[str(Mode.IN_PROCESS)].prepare()
+            return None

         raise ValueError(
-            "Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
+            "Please specify mode in: %s, %s, %s"
+            % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
         )

     def _get_client_translators(self):

@@ -606,6 +620,9 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
         elif overwrite_mode == Mode.LOCAL_CONTAINER:
             self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
             self._prepare_for_mode()
+        elif overwrite_mode == Mode.IN_PROCESS:
+            self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
+            self._prepare_for_mode()
         else:
             raise ValueError("Mode %s is not supported!" % overwrite_mode)

@@ -795,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
         self.dependencies.update({"requirements": mlflow_model_dependency_path})

     # Model Builder is a class to build the model for deployment.
-    # It supports two modes of deployment
+    # It supports two* modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
+    # 3/ In process mode with Transformers server in beta release
     def build(  # pylint: disable=R0911
         self,
         mode: Type[Mode] = None,

@@ -895,8 +913,10 @@ def build(  # pylint: disable=R0911

     def _build_validations(self):
         """Validations needed for model server overrides, or auto-detection or fallback"""
-        if self.mode == Mode.IN_PROCESS:
-            raise ValueError("IN_PROCESS mode is not supported yet!")
+        if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
+            raise ValueError(
+                "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
+            )

         if self.inference_spec and self.model:
             raise ValueError("Can only set one of the following: model, inference_spec.")

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 18 additions & 2 deletions

@@ -36,7 +36,10 @@
 )
 from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.optimize_utils import _is_optimized
-from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
+from sagemaker.serve.utils.predictors import (
+    TransformersLocalModePredictor,
+    TransformersInProcessModePredictor,
+)
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry

@@ -47,6 +50,7 @@

 logger = logging.getLogger(__name__)
 DEFAULT_TIMEOUT = 1800
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]


 """Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub

@@ -228,6 +232,18 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
             )
             return predictor

+        if self.mode == Mode.IN_PROCESS:
+            timeout = kwargs.get("model_data_download_timeout")
+
+            predictor = TransformersInProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                predictor,
+            )
+            return predictor
+
         self._set_instance(kwargs)

         if "mode" in kwargs:

@@ -293,7 +309,7 @@ def _build_transformers_env(self):

         self.pysdk_model = self._create_transformers_model()

-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode in LOCAL_MODES:
             self._prepare_for_mode()

         return self.pysdk_model
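
End to end, the in-process path mirrors local container mode: the deploy wrapper builds a TransformersInProcessModePredictor, starts the server inside the current Python process, and returns the predictor without provisioning any infrastructure. A hedged sketch continuing the ModelBuilder example above (payload is illustrative):

    predictor = model.deploy()          # starts the Transformers server in-process
    print(predictor.predict("hello"))   # invoke it like any other predictor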
