Skip to content

Commit bd02b60

Browse files
committed
fix model_uri usage issue
1 parent 5624346 commit bd02b60

File tree

6 files changed

+25
-12
lines changed

6 files changed

+25
-12
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2511,7 +2511,7 @@ def start_new(cls, estimator, inputs, experiment_config):
25112511
train_args = cls._get_train_args(estimator, inputs, experiment_config)
25122512

25132513
logger.debug("Train args after processing defaults: %s", train_args)
2514-
print("rohan debug: ", train_args)
2514+
25152515
estimator.sagemaker_session.train(**train_args)
25162516

25172517
return cls(estimator.sagemaker_session, estimator._current_job_name)

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
JUMPSTART_LOGGER,
5757
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5858
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
59+
JUMPSTART_MODEL_HUB_NAME,
5960
)
6061
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
6162
from sagemaker.jumpstart.factory import model
@@ -630,8 +631,13 @@ def _add_model_reference_arn_to_kwargs(
630631

631632
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
632633
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
633-
634-
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
634+
# hub_arn defaults to None unless the user specifies a hub_name
635+
# If no hub_name is specified, the public hub is assumed
636+
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
637+
if (
638+
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
639+
or is_private_hub
640+
):
635641
default_model_uri = model_uris.retrieve(
636642
model_scope=JumpStartScriptScope.TRAINING,
637643
instance_type=kwargs.instance_type,

src/sagemaker/jumpstart/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,6 +1940,11 @@ def use_inference_script_uri(self) -> bool:
19401940

19411941
def use_training_model_artifact(self) -> bool:
19421942
"""Returns True if the model should use a model uri when kicking off training job."""
1943+
# gated model never uses training model artifact
1944+
if self.gated_bucket:
1945+
return False
1946+
1947+
# otherwise, return true if a training model package is not set
19431948
return len(self.training_model_package_artifact_uris or {}) == 0
19441949

19451950
def is_gated_model(self) -> bool:

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,18 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
151151
},
152152
)
153153

154-
estimator = JumpStartEstimator.attach(
155-
training_job_name=estimator.latest_training_job.name,
156-
model_id=model_id,
157-
model_version=model_version,
158-
)
159-
160-
# uses ml.p3.2xlarge instance
161154
predictor = estimator.deploy(
162155
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
156+
role=get_sm_session().get_caller_identity_arn(),
157+
sagemaker_session=get_sm_session(),
163158
)
164159

165-
response = predictor.predict(["hello", "world"])
160+
payload = {
161+
"inputs": "some-payload",
162+
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
163+
}
164+
165+
response = predictor.predict(payload, custom_attributes="accept_eula=true")
166166

167167
assert response is not None
168168

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,6 @@ def test_gated_model_non_model_package_s3_uri(
688688
instance_count=1,
689689
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt"
690690
"orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04",
691-
model_uri='s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key',
692691
source_dir="s3://jumpstart-cache-prod-us-west-2/source-d"
693692
"irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz",
694693
entry_point="transfer_learning.py",

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,9 @@ def test_jumpstart_model_header():
332332
def test_use_training_model_artifact():
333333
specs1 = JumpStartModelSpecs(BASE_SPEC)
334334
assert specs1.use_training_model_artifact()
335+
specs1.gated_bucket = True
336+
assert not specs1.use_training_model_artifact()
337+
specs1.gated_bucket = False
335338
specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"}
336339
assert not specs1.use_training_model_artifact()
337340

0 commit comments

Comments
 (0)