Skip to content

Commit 32ec377

Browse files
authored
Merge branch 'master' into sagemaker-mlflow-extras
2 parents 6aa6822 + da86520 commit 32ec377

File tree

6 files changed

+74
-3
lines changed

6 files changed

+74
-3
lines changed

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ _Put an `x` in the boxes that apply. You can also fill these out after creating
2222
- [ ] I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes
2323
- [ ] I have checked that my tests are not configured for a specific region or account (if appropriate)
2424
- [ ] I have used [`unique_name_from_base`](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/utils.py#L77) to create resource names in integ tests (if appropriate)
25+
- [ ] If adding any dependency in requirements.txt files, I have spell checked and ensured they exist in PyPi
2526

2627
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

src/sagemaker/pipeline.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from sagemaker.drift_check_baselines import DriftCheckBaselines
2828
from sagemaker.metadata_properties import MetadataProperties
29+
from sagemaker.model import ModelPackage
2930
from sagemaker.model_card import (
3031
ModelCard,
3132
ModelPackageModelCard,
@@ -470,7 +471,18 @@ def register(
470471
model_card=model_card,
471472
)
472473

473-
self.sagemaker_session.create_model_package_from_containers(**model_pkg_args)
474+
model_package = self.sagemaker_session.create_model_package_from_containers(
475+
**model_pkg_args
476+
)
477+
478+
if model_package is not None and "ModelPackageArn" in model_package:
479+
return ModelPackage(
480+
role=self.role,
481+
model_package_arn=model_package.get("ModelPackageArn"),
482+
sagemaker_session=self.sagemaker_session,
483+
predictor_cls=self.predictor_cls,
484+
)
485+
return None
474486

475487
def transformer(
476488
self,

src/sagemaker/serve/utils/conda_in_process.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies:
3737
- botocore>=1.29.114
3838
- cachetools>=5.3.0
3939
- certifi==2022.12.7
40-
- harset-normalizer>=3.1.0
40+
- charset-normalizer>=3.1.0
4141
- click>=8.1.3
4242
- cloudpickle>=2.2.1
4343
- colorama>=0.4.4

src/sagemaker/serve/utils/in_process_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ blinker>=1.6.2
55
botocore>=1.29.114
66
cachetools>=5.3.0
77
certifi==2024.7.4
8-
harset-normalizer>=3.1.0
8+
charset-normalizer>=3.1.0
99
click>=8.1.3
1010
cloudpickle>=2.2.1
1111
colorama>=0.4.4

tests/integ/test_inference_pipeline.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
150150
assert "Could not find model" in str(exception.value)
151151

152152

153+
@pytest.mark.release
154+
def test_inference_pipeline_model_register(sagemaker_session):
155+
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
156+
endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
157+
sparkml_model_data = sagemaker_session.upload_data(
158+
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
159+
key_prefix="integ-test-data/sparkml/model",
160+
)
161+
162+
sparkml_model = SparkMLModel(
163+
model_data=sparkml_model_data,
164+
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
165+
sagemaker_session=sagemaker_session,
166+
)
167+
168+
model = PipelineModel(
169+
models=[sparkml_model],
170+
role="SageMakerRole",
171+
sagemaker_session=sagemaker_session,
172+
name=endpoint_name,
173+
)
174+
model_package_group_name = unique_name_from_base("pipeline-model-package")
175+
model_package = model.register(model_package_group_name=model_package_group_name)
176+
assert model_package.model_package_arn is not None
177+
178+
sagemaker_session.sagemaker_client.delete_model_package(
179+
ModelPackageName=model_package.model_package_arn
180+
)
181+
182+
sagemaker_session.sagemaker_client.delete_model_package_group(
183+
ModelPackageGroupName=model_package_group_name
184+
)
185+
186+
153187
@pytest.mark.slow_test
154188
@pytest.mark.flaky(reruns=5, reruns_delay=2)
155189
def test_inference_pipeline_model_deploy_and_update_endpoint(

tests/unit/test_pipeline_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,27 @@ def test_network_isolation(tfo, time, sagemaker_session):
420420
vpc_config=None,
421421
enable_network_isolation=True,
422422
)
423+
424+
425+
def test_pipeline_model_register(sagemaker_session):
426+
sagemaker_session.create_model_package_from_containers = Mock(
427+
name="create_model_package_from_containers",
428+
return_value={
429+
"ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
430+
},
431+
)
432+
framework_model = DummyFrameworkModel(sagemaker_session)
433+
sparkml_model = SparkMLModel(
434+
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session
435+
)
436+
model = PipelineModel(
437+
models=[framework_model, sparkml_model],
438+
role=ROLE,
439+
sagemaker_session=sagemaker_session,
440+
enable_network_isolation=True,
441+
)
442+
model_package = model.register()
443+
assert (
444+
model_package.model_package_arn
445+
== "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
446+
)

0 commit comments

Comments
 (0)