
Commit 13dd058

fix jumpstart curated hub bugs
1 parent: 1faecb4

File tree

1 file changed: +36 -0 lines changed

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 36 additions & 0 deletions
@@ -170,6 +170,42 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
     assert response is not None


+def test_jumpstart_hub_gated_estimator_with_eula_env_var(setup, add_model_references):
+
+    model_id, model_version = "meta-textgeneration-llama-2-7b", "*"
+
+    estimator = JumpStartEstimator(
+        model_id=model_id,
+        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
+        environment={
+            "accept_eula": "true",
+        },
+        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
+    )
+
+    estimator.fit(
+        inputs={
+            "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
+            f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
+        },
+    )
+
+    predictor = estimator.deploy(
+        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
+        role=get_sm_session().get_caller_identity_arn(),
+        sagemaker_session=get_sm_session(),
+    )
+
+    payload = {
+        "inputs": "some-payload",
+        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
+    }
+
+    response = predictor.predict(payload, custom_attributes="accept_eula=true")
+
+    assert response is not None
+
+
 def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):

     model_id, model_version = "meta-textgeneration-llama-2-7b", "*"
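
The added test covers accepting a gated hub model's EULA through the estimator's environment mapping rather than an explicit accept_eula argument. Below is a minimal sketch of that usage pattern outside the test harness; the hub name and S3 training path are hypothetical placeholders, while the model ID, the "accept_eula" environment key, and the custom_attributes value mirror the test above.

    from sagemaker.jumpstart.estimator import JumpStartEstimator

    # Hypothetical hub name and training data location; the integration test
    # reads these from environment variables and helper functions instead.
    estimator = JumpStartEstimator(
        model_id="meta-textgeneration-llama-2-7b",
        hub_name="my-curated-hub",
        environment={"accept_eula": "true"},  # accept the gated-model EULA for training
    )
    estimator.fit(inputs={"training": "s3://my-bucket/llama-2-training-data/"})

    predictor = estimator.deploy()
    # Gated models also require EULA acceptance at inference time,
    # passed as a custom attribute on the request.
    response = predictor.predict(
        {"inputs": "some-payload"},
        custom_attributes="accept_eula=true",
    )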
