@@ -170,6 +170,42 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
170
170
assert response is not None
171
171
172
172
173
def test_jumpstart_hub_gated_estimator_with_eula_env_var(setup, add_model_references):
    """End-to-end test: train and deploy a gated JumpStart hub model,
    accepting the EULA via the ``accept_eula`` container environment
    variable (rather than an explicit ``accept_eula`` fit argument)."""
    model_id, model_version = "meta-textgeneration-llama-2-7b", "*"

    # EULA acceptance for the gated model is conveyed through the training
    # container environment; the suite-ID tag lets test teardown find resources.
    estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
        environment={
            "accept_eula": "true",
        },
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    # Public JumpStart content bucket hosts the canonical training dataset
    # for this model/version pair.
    training_data_uri = (
        f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
        f"{get_training_dataset_for_model_and_version(model_id, model_version)}"
    )
    estimator.fit(inputs={"training": training_data_uri})

    predictor = estimator.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    payload = {
        "inputs": "some-payload",
        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
    }

    # Inference against a gated model still requires EULA acceptance via
    # custom attributes on the request.
    response = predictor.predict(payload, custom_attributes="accept_eula=true")

    assert response is not None
208
+
173
209
def test_jumpstart_hub_gated_estimator_without_eula (setup , add_model_references ):
174
210
175
211
model_id , model_version = "meta-textgeneration-llama-2-7b" , "*"
0 commit comments