Skip to content

Commit 69304c2

Browse files
authored
Merge branch 'master' into fastapi-inprocess
2 parents 88ee686 + e240518 commit 69304c2

31 files changed

+538
-470
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ max-returns=6
384384
max-branches=12
385385

386386
# Maximum number of statements in function / method body
387-
max-statements=100
387+
max-statements=105
388388

389389
# Maximum number of parents for a class (see R0901).
390390
max-parents=7

src/sagemaker/algorithm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,20 @@ def __init__(
157157
available (default: ``None``).
158158
**kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator
159159
to ignore the irrelevant arguments.
160+
161+
Raises:
162+
ValueError:
163+
- If an AWS IAM Role is not provided.
164+
- Bad value for instance type.
165+
RuntimeError:
166+
- When setting up custom VPC, only one of subnets and security_group_ids is provided
167+
- If instance_count > 1 (distributed training) with instance type local_gpu
168+
- If LocalSession is not used with instance type local or local_gpu
169+
- file:// output path used outside of local mode
170+
botocore.exceptions.ClientError:
171+
- algorithm arn is incorrect
172+
- insufficient permission to access/describe algorithm
173+
- algorithm is in a different region
160174
"""
161175
self.algorithm_arn = algorithm_arn
162176
super(AlgorithmEstimator, self).__init__(

src/sagemaker/base_predictor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ def update_endpoint(
430430
- If ``initial_instance_count``, ``instance_type``, or ``accelerator_type`` is
431431
specified and either ``model_name`` is ``None`` or there are multiple models
432432
associated with the endpoint.
433+
botocore.exceptions.ClientError: If SageMaker throws an error while creating
434+
endpoint config, describing endpoint or updating endpoint
433435
"""
434436
production_variants = None
435437
current_model_names = self._get_model_names()

src/sagemaker/environment_variables.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23-
from sagemaker.jumpstart.enums import JumpStartScriptScope
23+
from sagemaker.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
2424
from sagemaker.session import Session
2525

2626
logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ def retrieve_default(
3838
instance_type: Optional[str] = None,
3939
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
4040
config_name: Optional[str] = None,
41+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4142
) -> Dict[str, str]:
4243
"""Retrieves the default container environment variables for the model matching the arguments.
4344
@@ -70,6 +71,8 @@ def retrieve_default(
7071
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
7172
variables.
7273
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
74+
model_type (JumpStartModelType): The type of the model, can be open weights model
75+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7376
Returns:
7477
dict: The variables to use for the model.
7578
@@ -94,4 +97,5 @@ def retrieve_default(
9497
instance_type=instance_type,
9598
script=script,
9699
config_name=config_name,
100+
model_type=model_type,
97101
)

src/sagemaker/estimator.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,25 +590,36 @@ def __init__(
590590
self.dependencies = dependencies or []
591591
self.uploaded_code: Optional[UploadedCode] = None
592592

593-
# Check that the user properly sets both subnet and secutiry_groupe_ids
593+
# Check that the user properly sets both subnet and security_group_ids
594594
if (
595595
subnets is not None
596596
and security_group_ids is None
597597
or security_group_ids is not None
598598
and subnets is None
599599
):
600+
troubleshooting = (
601+
"Refer to this documentation on using custom VPC: "
602+
"https://sagemaker.readthedocs.io/en/v2.24.0/overview.html"
603+
"#secure-training-and-inference-with-vpc"
604+
)
605+
logger.error("Check troubleshooting guide for common errors: %s", troubleshooting)
606+
600607
raise RuntimeError(
601608
"When setting up custom VPC, both subnets and security_group_ids must be set"
602609
)
603610

604611
if self.instance_type in ("local", "local_gpu"):
605612
if self.instance_type == "local_gpu" and self.instance_count > 1:
606-
raise RuntimeError("Distributed Training in Local GPU is not supported")
613+
raise RuntimeError(
614+
"Distributed Training in Local GPU is not supported."
615+
" Set instance_count to 1."
616+
)
607617
self.sagemaker_session = sagemaker_session or LocalSession()
608618
if not isinstance(self.sagemaker_session, sagemaker.local.LocalSession):
609619
raise RuntimeError(
610620
"instance_type local or local_gpu is only supported with an"
611-
"instance of LocalSession"
621+
"instance of LocalSession. More details on local mode: "
622+
"https://sagemaker.readthedocs.io/en/stable/overview.html#local-mode"
612623
)
613624
else:
614625
self.sagemaker_session = sagemaker_session or Session()
@@ -631,7 +642,11 @@ def __init__(
631642
and not is_pipeline_variable(output_path)
632643
and output_path.startswith("file://")
633644
):
634-
raise RuntimeError("file:// output paths are only supported in Local Mode")
645+
raise RuntimeError(
646+
"The 'file://' output paths are only supported when using Local Mode. "
647+
"To resolve this issue, ensure you're running in Local Mode with a LocalSession, "
648+
"or use an 's3://' output path for jobs running on SageMaker instances."
649+
)
635650
self.output_path = output_path
636651
self.latest_training_job = None
637652
self.jobs = []
@@ -646,7 +661,12 @@ def __init__(
646661
# Now we marked that as Optional because we can fetch it from SageMakerConfig
647662
# Because of marking that parameter as optional, we should validate if it is None, even
648663
# after fetching the config.
649-
raise ValueError("An AWS IAM role is required to create an estimator.")
664+
raise ValueError(
665+
"An AWS IAM role is required to create an estimator. "
666+
"Please provide a valid `role` argument with the ARN of an IAM role"
667+
" that has the necessary SageMaker permissions."
668+
)
669+
650670
self.output_kms_key = resolve_value_from_config(
651671
output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
652672
)
@@ -1855,6 +1875,8 @@ def model_data(self):
18551875
if compression_type not in {"GZIP", "NONE"}:
18561876
raise ValueError(
18571877
f'Unrecognized training job output data compression type "{compression_type}"'
1878+
'. Please specify either "GZIP" or "NONE" as valid options for '
1879+
"the compression type."
18581880
)
18591881
# model data is in uncompressed form NOTE SageMaker Hosting mandates presence of
18601882
# trailing forward slash in S3 model data URI, so append one if necessary.

src/sagemaker/hyperparameters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23-
from sagemaker.jumpstart.enums import HyperparameterValidationMode
23+
from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType
2424
from sagemaker.jumpstart.validators import validate_hyperparameters
2525
from sagemaker.session import Session
2626

@@ -38,6 +38,7 @@ def retrieve_default(
3838
tolerate_deprecated_model: bool = False,
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4040
config_name: Optional[str] = None,
41+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4142
) -> Dict[str, str]:
4243
"""Retrieves the default training hyperparameters for the model matching the given arguments.
4344
@@ -71,6 +72,8 @@ def retrieve_default(
7172
specified, one is created using the default AWS configuration
7273
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
7374
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
75+
model_type (JumpStartModelType): The type of the model, can be open weights model
76+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7477
Returns:
7578
dict: The hyperparameters to use for the model.
7679
@@ -93,6 +96,7 @@ def retrieve_default(
9396
tolerate_deprecated_model=tolerate_deprecated_model,
9497
sagemaker_session=sagemaker_session,
9598
config_name=config_name,
99+
model_type=model_type,
96100
)
97101

98102

src/sagemaker/image_uris.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from sagemaker import utils
2424
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
25+
from sagemaker.jumpstart.enums import JumpStartModelType
2526
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2627
from sagemaker.spark import defaults
2728
from sagemaker.jumpstart import artifacts
@@ -72,6 +73,7 @@ def retrieve(
7273
serverless_inference_config=None,
7374
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
7475
config_name=None,
76+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
7577
) -> str:
7678
"""Retrieves the ECR URI for the Docker image matching the given arguments.
7779
@@ -128,6 +130,8 @@ def retrieve(
128130
specified, one is created using the default AWS configuration
129131
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
130132
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
133+
model_type (JumpStartModelType): The type of the model, can be open weights model
134+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
131135
132136
Returns:
133137
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -169,6 +173,7 @@ def retrieve(
169173
tolerate_deprecated_model,
170174
sagemaker_session=sagemaker_session,
171175
config_name=config_name,
176+
model_type=model_type,
172177
)
173178

174179
if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
2020
)
2121
from sagemaker.jumpstart.enums import (
22+
JumpStartModelType,
2223
JumpStartScriptScope,
2324
)
2425
from sagemaker.jumpstart.utils import (
@@ -41,6 +42,7 @@ def _retrieve_default_environment_variables(
4142
instance_type: Optional[str] = None,
4243
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
4344
config_name: Optional[str] = None,
45+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4446
) -> Dict[str, str]:
4547
"""Retrieves the inference environment variables for the model matching the given arguments.
4648
@@ -73,6 +75,8 @@ def _retrieve_default_environment_variables(
7375
script (JumpStartScriptScope): The JumpStart script for which to retrieve
7476
environment variables.
7577
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
78+
model_type (JumpStartModelType): The type of the model, can be open weights model
79+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7680
Returns:
7781
dict: the inference environment variables to use for the model.
7882
"""
@@ -91,6 +95,7 @@ def _retrieve_default_environment_variables(
9195
tolerate_deprecated_model=tolerate_deprecated_model,
9296
sagemaker_session=sagemaker_session,
9397
config_name=config_name,
98+
model_type=model_type,
9499
)
95100

96101
default_environment_variables: Dict[str, str] = {}
@@ -130,6 +135,7 @@ def _retrieve_default_environment_variables(
130135
sagemaker_session=sagemaker_session,
131136
instance_type=instance_type,
132137
config_name=config_name,
138+
model_type=model_type,
133139
)
134140
)
135141

@@ -178,6 +184,7 @@ def _retrieve_gated_model_uri_env_var_value(
178184
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
179185
instance_type: Optional[str] = None,
180186
config_name: Optional[str] = None,
187+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
181188
) -> Optional[str]:
182189
"""Retrieves the gated model env var URI matching the given arguments.
183190
@@ -204,7 +211,8 @@ def _retrieve_gated_model_uri_env_var_value(
204211
instance_type (str): An instance type to optionally supply in order to get
205212
environment variables specific for the instance type.
206213
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
207-
214+
model_type (JumpStartModelType): The type of the model, can be open weights model
215+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
208216
Returns:
209217
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
210218
have gated training artifacts.
@@ -227,6 +235,7 @@ def _retrieve_gated_model_uri_env_var_value(
227235
tolerate_deprecated_model=tolerate_deprecated_model,
228236
sagemaker_session=sagemaker_session,
229237
config_name=config_name,
238+
model_type=model_type,
230239
)
231240

232241
s3_key: Optional[str] = (

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
)
1919
from sagemaker.jumpstart.enums import (
20+
JumpStartModelType,
2021
JumpStartScriptScope,
2122
VariableScope,
2223
)
@@ -38,6 +39,7 @@ def _retrieve_default_hyperparameters(
3839
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3940
instance_type: Optional[str] = None,
4041
config_name: Optional[str] = None,
42+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4143
):
4244
"""Retrieves the training hyperparameters for the model matching the given arguments.
4345
@@ -71,6 +73,8 @@ def _retrieve_default_hyperparameters(
7173
instance_type (str): An instance type to optionally supply in order to get hyperparameters
7274
specific for the instance type.
7375
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
76+
model_type (JumpStartModelType): The type of the model, can be open weights model
77+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
7478
Returns:
7579
dict: the hyperparameters to use for the model.
7680
"""
@@ -89,6 +93,7 @@ def _retrieve_default_hyperparameters(
8993
tolerate_deprecated_model=tolerate_deprecated_model,
9094
sagemaker_session=sagemaker_session,
9195
config_name=config_name,
96+
model_type=model_type,
9297
)
9398

9499
default_hyperparameters: Dict[str, str] = {}

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2020
)
2121
from sagemaker.jumpstart.enums import (
22+
JumpStartModelType,
2223
JumpStartScriptScope,
2324
ModelFramework,
2425
)
@@ -48,6 +49,7 @@ def _retrieve_image_uri(
4849
tolerate_deprecated_model: bool = False,
4950
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
5051
config_name: Optional[str] = None,
52+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
5153
):
5254
"""Retrieves the container image URI for JumpStart models.
5355
@@ -100,6 +102,8 @@ def _retrieve_image_uri(
100102
specified, one is created using the default AWS configuration
101103
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
102104
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
105+
model_type (JumpStartModelType): The type of the model, can be open weights model
106+
or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
103107
Returns:
104108
str: the ECR URI for the corresponding SageMaker Docker image.
105109
@@ -123,6 +127,7 @@ def _retrieve_image_uri(
123127
tolerate_deprecated_model=tolerate_deprecated_model,
124128
sagemaker_session=sagemaker_session,
125129
config_name=config_name,
130+
model_type=model_type,
126131
)
127132

128133
if image_scope == JumpStartScriptScope.INFERENCE:

0 commit comments

Comments
 (0)