Skip to content
Merged

Master #5112

Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sagemaker.jumpstart.hub.utils import (
construct_hub_model_arn_from_inputs,
construct_hub_model_reference_arn_from_inputs,
generate_hub_arn_for_init_kwargs,
)
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.session import Session
Expand Down Expand Up @@ -291,6 +292,10 @@ def get_model_specs(
# Users only input model id, not contentType, so first try to describe with ModelReference, then with Model
if hub_arn:
try:
hub_arn = generate_hub_arn_for_init_kwargs(
hub_name=hub_arn, region=region, session=sagemaker_session
)

hub_model_arn = construct_hub_model_reference_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def fit(
accept the end-user license agreement (EULA) that some
models require. (Default: None).
"""
self.model_access_config = get_model_access_config(accept_eula)
self.model_access_config = get_model_access_config(accept_eula, self.environment)
self.hub_access_config = get_hub_access_config(
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
)
Expand Down
24 changes: 19 additions & 5 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,8 +1641,14 @@ def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
"""
if accept_eula is not None and init_kwargs["environment"]:
del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY]
if accept_eula is not None and init_kwargs.get("environment") is not None:
if (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can do init_kwargs["environment"].pop(constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, None)

constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
in init_kwargs["environment"]
):
del init_kwargs["environment"][
constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY
]


def get_hub_access_config(hub_content_arn: Optional[str]):
Expand All @@ -1659,16 +1665,24 @@ def get_hub_access_config(hub_content_arn: Optional[str]):
return hub_access_config


def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict] = None):
    """Build the model access config from the EULA flag or the estimator environment.

    The EULA decision can come from two places: the ``accept_eula`` argument
    or an ``"accept_eula"`` entry in ``environment``. Supplying both is
    rejected so the caller's intent is never ambiguous.

    Args:
        accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally
            passed in to Estimator.fit().
        environment (Optional[dict]): Environment variables configured on the estimator.
            May be ``None`` or empty, in which case only ``accept_eula`` is considered.
            (Default: None).

    Returns:
        Optional[dict]: ``{"AcceptEula": <bool>}`` when a EULA decision is available
        from either source, otherwise ``None``.

    Raises:
        ValueError: If ``accept_eula`` is provided and the environment also carries a
            non-empty ``"accept_eula"`` entry.
    """
    # ``environment`` is optional; guard the lookup so a missing environment
    # does not fail with AttributeError on ``None.get``.
    env_var_eula = environment.get("accept_eula") if environment else None

    # NOTE(review): a reviewer suggested logging a warning here instead of raising;
    # the exception is kept so existing behavior is unchanged.
    if env_var_eula and accept_eula is not None:
        raise ValueError(
            "Cannot pass in both accept_eula and environment variables. "
            "Please remove the environment variable and pass in the accept_eula parameter."
        )

    # An explicit argument wins over an (empty) environment entry.
    if accept_eula is not None:
        return {"AcceptEula": accept_eula}

    # Only the exact string "true" counts as acceptance for the environment entry.
    if env_var_eula is not None:
        return {"AcceptEula": env_var_eula == "true"}

    return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,42 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
assert response is not None


def test_jumpstart_hub_gated_estimator_with_eula_env_var(setup, add_model_references):
    """Fit, deploy and query a gated hub model with the EULA set via the environment.

    The estimator is created with ``environment={"accept_eula": "true"}`` and
    ``fit`` is called without an ``accept_eula`` argument. The test passes when
    the deployed predictor returns a non-None response.
    """

    model_id = "meta-textgeneration-llama-2-7b"
    model_version = "*"

    hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
    estimator_tags = [
        {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    ]
    estimator = JumpStartEstimator(
        model_id=model_id,
        hub_name=hub_name,
        environment={"accept_eula": "true"},
        tags=estimator_tags,
    )

    # Training data lives in the JumpStart content bucket of the default region.
    bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit(inputs={"training": f"s3://{bucket}/{dataset_key}"})

    predictor = estimator.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    generation_parameters = {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}
    payload = {"inputs": "some-payload", "parameters": generation_parameters}

    response = predictor.predict(payload, custom_attributes="accept_eula=true")

    assert response is not None


def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references):

model_id, model_version = "meta-textgeneration-llama-2-7b", "*"
Expand Down
Loading