File tree Expand file tree Collapse file tree 4 files changed +12
-4
lines changed
tests/unit/sagemaker/jumpstart Expand file tree Collapse file tree 4 files changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -608,6 +608,7 @@ def _validate_model_id_and_get_type_hook():
608
608
self .role = estimator_init_kwargs .role
609
609
self .sagemaker_session = estimator_init_kwargs .sagemaker_session
610
610
self ._enable_network_isolation = estimator_init_kwargs .enable_network_isolation
611
+ super (JumpStartEstimator , self ).__init__ (** estimator_init_kwargs .to_kwargs_dict ())
611
612
self ._internal_config = JumpStartModelInternalConfig (
612
613
specs = verify_model_region_and_return_specs (
613
614
region = self .region ,
@@ -617,9 +618,10 @@ def _validate_model_id_and_get_type_hook():
617
618
model_type = self .model_type ,
618
619
scope = JumpStartScriptScope .TRAINING ,
619
620
sagemaker_session = self .sagemaker_session ,
621
+ tolerate_vulnerable_model = self .tolerate_vulnerable_model ,
622
+ tolerate_deprecated_model = self .tolerate_deprecated_model ,
620
623
)
621
624
)
622
- super (JumpStartEstimator , self ).__init__ (** estimator_init_kwargs .to_kwargs_dict ())
623
625
624
626
def fit (
625
627
self ,
Original file line number Diff line number Diff line change @@ -361,6 +361,9 @@ def _validate_model_id_and_type():
361
361
self .log_subscription_warning ()
362
362
363
363
model_init_kwargs_dict = model_init_kwargs .to_kwargs_dict ()
364
+ super (JumpStartModel , self ).__init__ (** model_init_kwargs_dict )
365
+
366
+ self .model_package_arn = model_init_kwargs .model_package_arn
364
367
self ._internal_config = JumpStartModelInternalConfig (
365
368
specs = verify_model_region_and_return_specs (
366
369
region = self .region ,
@@ -370,11 +373,10 @@ def _validate_model_id_and_type():
370
373
model_type = self .model_type ,
371
374
scope = JumpStartScriptScope .INFERENCE ,
372
375
sagemaker_session = self .sagemaker_session ,
376
+ tolerate_vulnerable_model = self .tolerate_vulnerable_model ,
377
+ tolerate_deprecated_model = self .tolerate_deprecated_model ,
373
378
)
374
379
)
375
- super (JumpStartModel , self ).__init__ (** model_init_kwargs_dict )
376
-
377
- self .model_package_arn = model_init_kwargs .model_package_arn
378
380
379
381
def log_subscription_warning (self ) -> None :
380
382
"""Log message prompting the customer to subscribe to the proprietary model."""
Original file line number Diff line number Diff line change @@ -1171,8 +1171,10 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
1171
1171
@mock .patch ("sagemaker.jumpstart.estimator.get_init_kwargs" )
1172
1172
@mock .patch ("sagemaker.jumpstart.estimator.Estimator.__init__" )
1173
1173
@mock .patch ("sagemaker.jumpstart.estimator.validate_model_id_and_get_type" )
1174
+ @mock .patch ("sagemaker.jumpstart.estimator.verify_model_region_and_return_specs" )
1174
1175
def test_validate_model_id_and_get_type (
1175
1176
self ,
1177
+ verify_model_region_and_return_specs : mock .Mock ,
1176
1178
mock_validate_model_id_and_get_type : mock .Mock ,
1177
1179
mock_init : mock .Mock ,
1178
1180
mock_get_init_kwargs : mock .Mock ,
Original file line number Diff line number Diff line change @@ -758,8 +758,10 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
758
758
@mock .patch ("sagemaker.jumpstart.model.get_init_kwargs" )
759
759
@mock .patch ("sagemaker.jumpstart.model.Model.__init__" )
760
760
@mock .patch ("sagemaker.jumpstart.model.validate_model_id_and_get_type" )
761
+ @mock .patch ("sagemaker.jumpstart.model.verify_model_region_and_return_specs" )
761
762
def test_validate_model_id_and_get_type (
762
763
self ,
764
+ verify_model_region_and_return_specs : mock .Mock ,
763
765
mock_validate_model_id_and_get_type : mock .Mock ,
764
766
mock_init : mock .Mock ,
765
767
mock_get_init_kwargs : mock .Mock ,
You can’t perform that action at this time.
0 commit comments