Skip to content

Commit e39d724

Browse files
authored
Merge branch 'master' into sagemaker-mlflow-extras
2 parents a126c5d + 56ecc76 commit e39d724

File tree

7 files changed

+100
-9
lines changed

7 files changed

+100
-9
lines changed

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ _Put an `x` in the boxes that apply. You can also fill these out after creating
2222
- [ ] I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes
2323
- [ ] I have checked that my tests are not configured for a specific region or account (if appropriate)
2424
- [ ] I have used [`unique_name_from_base`](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/utils.py#L77) to create resource names in integ tests (if appropriate)
25+
- [ ] If adding any dependency in requirements.txt files, I have spell checked and ensured they exist in PyPi
2526

2627
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

src/sagemaker/pipeline.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from sagemaker.drift_check_baselines import DriftCheckBaselines
2828
from sagemaker.metadata_properties import MetadataProperties
29+
from sagemaker.model import ModelPackage
2930
from sagemaker.model_card import (
3031
ModelCard,
3132
ModelPackageModelCard,
@@ -470,7 +471,18 @@ def register(
470471
model_card=model_card,
471472
)
472473

473-
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
474+
model_package = self.sagemaker_session.create_model_package_from_containers(
475+
**model_pkg_args
476+
)
477+
478+
if model_package is not None and "ModelPackageArn" in model_package:
479+
return ModelPackage(
480+
role=self.role,
481+
model_package_arn=model_package.get("ModelPackageArn"),
482+
sagemaker_session=self.sagemaker_session,
483+
predictor_cls=self.predictor_cls,
484+
)
485+
return None
474486

475487
def transformer(
476488
self,

src/sagemaker/serve/utils/conda_in_process.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies:
3737
- botocore>=1.29.114
3838
- cachetools>=5.3.0
3939
- certifi==2022.12.7
40-
- harset-normalizer>=3.1.0
40+
- charset-normalizer>=3.1.0
4141
- click>=8.1.3
4242
- cloudpickle>=2.2.1
4343
- colorama>=0.4.4

src/sagemaker/serve/utils/in_process_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ blinker>=1.6.2
55
botocore>=1.29.114
66
cachetools>=5.3.0
77
certifi==2024.7.4
8-
harset-normalizer>=3.1.0
8+
charset-normalizer>=3.1.0
99
click>=8.1.3
1010
cloudpickle>=2.2.1
1111
colorama>=0.4.4

tests/integ/test_inference_pipeline.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
150150
assert "Could not find model" in str(exception.value)
151151

152152

153+
@pytest.mark.release
154+
def test_inference_pipeline_model_register(sagemaker_session):
155+
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
156+
endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
157+
sparkml_model_data = sagemaker_session.upload_data(
158+
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
159+
key_prefix="integ-test-data/sparkml/model",
160+
)
161+
162+
sparkml_model = SparkMLModel(
163+
model_data=sparkml_model_data,
164+
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
165+
sagemaker_session=sagemaker_session,
166+
)
167+
168+
model = PipelineModel(
169+
models=[sparkml_model],
170+
role="SageMakerRole",
171+
sagemaker_session=sagemaker_session,
172+
name=endpoint_name,
173+
)
174+
model_package_group_name = unique_name_from_base("pipeline-model-package")
175+
model_package = model.register(model_package_group_name=model_package_group_name)
176+
assert model_package.model_package_arn is not None
177+
178+
sagemaker_session.sagemaker_client.delete_model_package(
179+
ModelPackageName=model_package.model_package_arn
180+
)
181+
182+
sagemaker_session.sagemaker_client.delete_model_package_group(
183+
ModelPackageGroupName=model_package_group_name
184+
)
185+
186+
153187
@pytest.mark.slow_test
154188
@pytest.mark.flaky(reruns=5, reruns_delay=2)
155189
def test_inference_pipeline_model_deploy_and_update_endpoint(

tests/unit/sagemaker/image_uris/test_graviton.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@
3030
]
3131

3232

33-
def _test_graviton_framework_uris(framework, version, py_version, account, region):
33+
def _test_graviton_framework_uris(
34+
framework, version, py_version, account, region, container_version="ubuntu20.04-sagemaker"
35+
):
3436
for instance_type in GRAVITON_INSTANCE_TYPES:
3537
uri = image_uris.retrieve(framework, region, instance_type=instance_type, version=version)
3638
expected = _expected_graviton_framework_uri(
37-
framework, version, py_version, account, region=region
39+
framework,
40+
version,
41+
py_version,
42+
account,
43+
region=region,
44+
container_version=container_version,
3845
)
3946
assert expected == uri
4047

@@ -50,11 +57,21 @@ def test_graviton_framework_uris(load_config_and_file_name, scope):
5057
for version in VERSIONS:
5158
ACCOUNTS = config[scope]["versions"][version]["registries"]
5259
py_versions = config[scope]["versions"][version]["py_versions"]
60+
container_version = (
61+
config[scope]["versions"][version].get("container_version", {}).get("cpu", None)
62+
)
63+
if container_version:
64+
container_version = container_version + "-sagemaker"
5365
for py_version in py_versions:
5466
for region in ACCOUNTS.keys():
55-
_test_graviton_framework_uris(
56-
framework, version, py_version, ACCOUNTS[region], region
57-
)
67+
if container_version:
68+
_test_graviton_framework_uris(
69+
framework, version, py_version, ACCOUNTS[region], region, container_version
70+
)
71+
else:
72+
_test_graviton_framework_uris(
73+
framework, version, py_version, ACCOUNTS[region], region
74+
)
5875

5976

6077
def _test_graviton_unsupported_framework(framework, region, framework_version):
@@ -183,11 +200,14 @@ def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_un
183200
assert "Unsupported instance type: m5." in str(error)
184201

185202

186-
def _expected_graviton_framework_uri(framework, version, py_version, account, region):
203+
def _expected_graviton_framework_uri(
204+
framework, version, py_version, account, region, container_version
205+
):
187206
return expected_uris.graviton_framework_uri(
188207
"{}-inference-graviton".format(framework),
189208
fw_version=version,
190209
py_version=py_version,
191210
account=account,
192211
region=region,
212+
container_version=container_version,
193213
)

tests/unit/test_pipeline_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,27 @@ def test_network_isolation(tfo, time, sagemaker_session):
420420
vpc_config=None,
421421
enable_network_isolation=True,
422422
)
423+
424+
425+
def test_pipeline_model_register(sagemaker_session):
426+
sagemaker_session.create_model_package_from_containers = Mock(
427+
name="create_model_package_from_containers",
428+
return_value={
429+
"ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
430+
},
431+
)
432+
framework_model = DummyFrameworkModel(sagemaker_session)
433+
sparkml_model = SparkMLModel(
434+
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session
435+
)
436+
model = PipelineModel(
437+
models=[framework_model, sparkml_model],
438+
role=ROLE,
439+
sagemaker_session=sagemaker_session,
440+
enable_network_isolation=True,
441+
)
442+
model_package = model.register()
443+
assert (
444+
model_package.model_package_arn
445+
== "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
446+
)

0 commit comments

Comments
 (0)