Skip to content

Commit d74bd4e

Browse files
committed
fix: tests
1 parent 6e38a7a commit d74bd4e

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ def _validate_model_id_and_get_type_hook():
608608
self.role = estimator_init_kwargs.role
609609
self.sagemaker_session = estimator_init_kwargs.sagemaker_session
610610
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation
611+
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
611612
self._internal_config = JumpStartModelInternalConfig(
612613
specs=verify_model_region_and_return_specs(
613614
region=self.region,
@@ -617,9 +618,10 @@ def _validate_model_id_and_get_type_hook():
617618
model_type=self.model_type,
618619
scope=JumpStartScriptScope.TRAINING,
619620
sagemaker_session=self.sagemaker_session,
621+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
622+
tolerate_deprecated_model=self.tolerate_deprecated_model,
620623
)
621624
)
622-
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
623625

624626
def fit(
625627
self,

src/sagemaker/jumpstart/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ def _validate_model_id_and_type():
361361
self.log_subscription_warning()
362362

363363
model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict()
364+
super(JumpStartModel, self).__init__(**model_init_kwargs_dict)
365+
366+
self.model_package_arn = model_init_kwargs.model_package_arn
364367
self._internal_config = JumpStartModelInternalConfig(
365368
specs=verify_model_region_and_return_specs(
366369
region=self.region,
@@ -370,11 +373,10 @@ def _validate_model_id_and_type():
370373
model_type=self.model_type,
371374
scope=JumpStartScriptScope.INFERENCE,
372375
sagemaker_session=self.sagemaker_session,
376+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
377+
tolerate_deprecated_model=self.tolerate_deprecated_model,
373378
)
374379
)
375-
super(JumpStartModel, self).__init__(**model_init_kwargs_dict)
376-
377-
self.model_package_arn = model_init_kwargs.model_package_arn
378380

379381
def log_subscription_warning(self) -> None:
380382
"""Log message prompting the customer to subscribe to the proprietary model."""

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,8 +1171,10 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
11711171
@mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs")
11721172
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
11731173
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1174+
@mock.patch("sagemaker.jumpstart.estimator.verify_model_region_and_return_specs")
11741175
def test_validate_model_id_and_get_type(
11751176
self,
1177+
verify_model_region_and_return_specs: mock.Mock,
11761178
mock_validate_model_id_and_get_type: mock.Mock,
11771179
mock_init: mock.Mock,
11781180
mock_get_init_kwargs: mock.Mock,

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,8 +758,10 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
758758
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
759759
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
760760
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
761+
@mock.patch("sagemaker.jumpstart.model.verify_model_region_and_return_specs")
761762
def test_validate_model_id_and_get_type(
762763
self,
764+
verify_model_region_and_return_specs: mock.Mock,
763765
mock_validate_model_id_and_get_type: mock.Mock,
764766
mock_init: mock.Mock,
765767
mock_get_init_kwargs: mock.Mock,

0 commit comments

Comments
 (0)