@@ -60,17 +60,12 @@ def add_model_references():
6060
6161
6262def test_jumpstart_hub_estimator (setup , add_model_references ):
63-
6463 model_id , model_version = "huggingface-spc-bert-base-cased" , "*"
6564
66- sagemaker_session = get_sm_session ()
67-
6865 estimator = JumpStartEstimator (
6966 model_id = model_id ,
70- role = sagemaker_session .get_caller_identity_arn (),
71- sagemaker_session = sagemaker_session ,
72- tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
7367 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ],
68+ tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
7469 )
7570
7671 estimator .fit (
@@ -85,22 +80,20 @@ def test_jumpstart_hub_estimator(setup, add_model_references):
8580 training_job_name = estimator .latest_training_job .name ,
8681 model_id = model_id ,
8782 model_version = model_version ,
88- sagemaker_session = get_sm_session (),
8983 )
9084
9185 # uses ml.p3.2xlarge instance
9286 predictor = estimator .deploy (
9387 tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
94- role = get_sm_session ().get_caller_identity_arn (),
95- sagemaker_session = get_sm_session (),
9688 )
9789
9890 response = predictor .predict (["hello" , "world" ])
9991
10092 assert response is not None
10193
10294
103- def test_jumpstart_hub_estimator_with_default_session (setup , add_model_references ):
95+ def test_jumpstart_hub_estimator_with_session (setup , add_model_references ):
96+
10497 model_id , model_version = "huggingface-spc-bert-base-cased" , "*"
10598
10699 sagemaker_session = get_sm_session ()
@@ -125,12 +118,14 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference
125118 training_job_name = estimator .latest_training_job .name ,
126119 model_id = model_id ,
127120 model_version = model_version ,
121+ sagemaker_session = get_sm_session (),
128122 )
129123
130124 # uses ml.p3.2xlarge instance
131125 predictor = estimator .deploy (
132126 tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
133127 role = get_sm_session ().get_caller_identity_arn (),
128+ sagemaker_session = get_sm_session (),
134129 )
135130
136131 response = predictor .predict (["hello" , "world" ])
@@ -144,9 +139,8 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
144139
145140 estimator = JumpStartEstimator (
146141 model_id = model_id ,
147- role = get_sm_session ().get_caller_identity_arn (),
148- sagemaker_session = get_sm_session (),
149142 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ],
143+ tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
150144 )
151145
152146 estimator .fit (
@@ -161,14 +155,11 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
161155 training_job_name = estimator .latest_training_job .name ,
162156 model_id = model_id ,
163157 model_version = model_version ,
164- sagemaker_session = get_sm_session (),
165158 )
166159
167160 # uses ml.p3.2xlarge instance
168161 predictor = estimator .deploy (
169162 tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
170- role = get_sm_session ().get_caller_identity_arn (),
171- sagemaker_session = get_sm_session (),
172163 )
173164
174165 response = predictor .predict (["hello" , "world" ])
@@ -182,9 +173,8 @@ def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references)
182173
183174 estimator = JumpStartEstimator (
184175 model_id = model_id ,
185- role = get_sm_session ().get_caller_identity_arn (),
186- sagemaker_session = get_sm_session (),
187176 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ],
177+ tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
188178 )
189179 with pytest .raises (Exception ):
190180 estimator .fit (
0 commit comments