Skip to content

Commit c6578e7

Browse files
committed
fix: unit test cases
1 parent 27b29de commit c6578e7

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,7 @@ def register(
549549
model_package_group_name = utils.base_name_from_image(
550550
self.image_uri, default_base_name=ModelPackage.__name__
551551
)
552-
if (
553-
model_package_group_name is not None
554-
and self.model_type is not None
555-
and self.model_type is not JumpStartModelType.PROPRIETARY
556-
):
552+
if model_package_group_name is not None and model_type is not JumpStartModelType.PROPRIETARY:
557553
container_def = self.prepare_container_def(accept_eula=accept_eula)
558554
container_def = update_container_with_inference_params(
559555
framework=framework,

tests/unit/test_estimator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4402,7 +4402,10 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
44024402
framework = "TENSORFLOW"
44034403
framework_version = "2.9"
44044404
nearest_model_name = "resnet50"
4405-
4405+
model_card = {
4406+
'ModelCardStatus': ModelCardStatusEnum.DRAFT,
4407+
'ModelCardContent': '{}'
4408+
}
44064409
estimator.register(
44074410
content_types=content_types,
44084411
response_types=response_types,
@@ -4425,6 +4428,7 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
44254428
"marketplace_cert": False,
44264429
"sample_payload_url": sample_payload_url,
44274430
"task": task,
4431+
"model_card": model_card,
44284432
}
44294433
sagemaker_session.create_model_package_from_containers.assert_called_with(
44304434
**expected_create_model_package_request
@@ -4454,6 +4458,10 @@ def test_register_inference_image(sagemaker_session):
44544458
framework = "TENSORFLOW"
44554459
framework_version = "2.9"
44564460
nearest_model_name = "resnet50"
4461+
model_card = {
4462+
'ModelCardStatus': ModelCardStatusEnum.DRAFT,
4463+
'ModelCardContent': '{}'
4464+
}
44574465

44584466
estimator.register(
44594467
content_types=content_types,
@@ -4480,6 +4488,7 @@ def test_register_inference_image(sagemaker_session):
44804488
"marketplace_cert": False,
44814489
"sample_payload_url": sample_payload_url,
44824490
"task": task,
4491+
"model_card": model_card,
44834492
}
44844493
sagemaker_session.create_model_package_from_containers.assert_called_with(
44854494
**expected_create_model_package_request

0 commit comments

Comments
 (0)