Skip to content

Commit 622f706

Browse files
committed
fix unit test failure and fix bug around versioning
1 parent f8f0e14 commit 622f706

File tree

7 files changed

+27
-27
lines changed

7 files changed

+27
-27
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2511,6 +2511,7 @@ def start_new(cls, estimator, inputs, experiment_config):
25112511
train_args = cls._get_train_args(estimator, inputs, experiment_config)
25122512

25132513
logger.debug("Train args after processing defaults: %s", train_args)
2514+
print("rohan debug: ", train_args)
25142515
estimator.sagemaker_session.train(**train_args)
25152516

25162517
return cls(estimator.sagemaker_session, estimator._current_job_name)

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
630630
if json_obj.get("ValidationSupported")
631631
else None
632632
)
633-
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
634633
self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
635634
self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
636635
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -671,6 +670,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
671670
)
672671

673672
if self.training_supported:
673+
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
674674
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
675675
"TrainingModelPackageArtifactUri"
676676
)

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
25+
from packaging import version
2526

2627
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
2728

@@ -219,9 +220,7 @@ def get_hub_model_version(
219220
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
220221

221222
try:
222-
hub_content_summaries = sagemaker_session.list_hub_content_versions(
223-
hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type
224-
).get("HubContentSummaries")
223+
hub_content_summaries = _list_hub_content_versions_helper(hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type, sagemaker_session=sagemaker_session)
225224
except Exception as ex:
226225
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
227226

@@ -237,14 +236,26 @@ def get_hub_model_version(
237236
return marketplace_hub_content_version
238237
raise
239238

239+
def _list_hub_content_versions_helper(hub_name, hub_content_name, hub_content_type, sagemaker_session):
240+
all_hub_content_summaries = []
241+
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
242+
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type
243+
)
244+
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
245+
while "NextToken" in list_hub_content_versions_response:
246+
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
247+
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type, next_token=list_hub_content_versions_response["NextToken"]
248+
)
249+
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
250+
return all_hub_content_summaries
240251

241252
def _get_hub_model_version_for_open_weight_version(
242253
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
243254
) -> str:
244255
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
245256

246257
if hub_model_version == "*" or hub_model_version is None:
247-
return str(max(available_model_versions))
258+
return str(max(version.parse(v) for v in available_model_versions))
248259

249260
try:
250261
spec = SpecifierSet(f"=={hub_model_version}")

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,11 +1940,6 @@ def use_inference_script_uri(self) -> bool:
19401940

19411941
def use_training_model_artifact(self) -> bool:
19421942
"""Returns True if the model should use a model uri when kicking off training job."""
1943-
# gated model never use training model artifact
1944-
if self.gated_bucket:
1945-
return False
1946-
1947-
# otherwise, return true is a training model package is not set
19481943
return len(self.training_model_package_artifact_uris or {}) == 0
19491944

19501945
def is_gated_model(self) -> bool:

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,12 @@ def add_model_references():
6060

6161

6262
def test_jumpstart_hub_estimator(setup, add_model_references):
63-
6463
model_id, model_version = "huggingface-spc-bert-base-cased", "*"
6564

66-
sagemaker_session = get_sm_session()
67-
6865
estimator = JumpStartEstimator(
6966
model_id=model_id,
70-
role=sagemaker_session.get_caller_identity_arn(),
71-
sagemaker_session=sagemaker_session,
72-
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7367
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
68+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
7469
)
7570

7671
estimator.fit(
@@ -85,22 +80,20 @@ def test_jumpstart_hub_estimator(setup, add_model_references):
8580
training_job_name=estimator.latest_training_job.name,
8681
model_id=model_id,
8782
model_version=model_version,
88-
sagemaker_session=get_sm_session(),
8983
)
9084

9185
# uses ml.p3.2xlarge instance
9286
predictor = estimator.deploy(
9387
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
94-
role=get_sm_session().get_caller_identity_arn(),
95-
sagemaker_session=get_sm_session(),
9688
)
9789

9890
response = predictor.predict(["hello", "world"])
9991

10092
assert response is not None
10193

10294

103-
def test_jumpstart_hub_estimator_with_default_session(setup, add_model_references):
95+
def test_jumpstart_hub_estimator_with_session(setup, add_model_references):
96+
10497
model_id, model_version = "huggingface-spc-bert-base-cased", "*"
10598

10699
sagemaker_session = get_sm_session()
@@ -125,12 +118,14 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference
125118
training_job_name=estimator.latest_training_job.name,
126119
model_id=model_id,
127120
model_version=model_version,
121+
sagemaker_session=get_sm_session(),
128122
)
129123

130124
# uses ml.p3.2xlarge instance
131125
predictor = estimator.deploy(
132126
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
133127
role=get_sm_session().get_caller_identity_arn(),
128+
sagemaker_session=get_sm_session(),
134129
)
135130

136131
response = predictor.predict(["hello", "world"])
@@ -144,9 +139,8 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
144139

145140
estimator = JumpStartEstimator(
146141
model_id=model_id,
147-
role=get_sm_session().get_caller_identity_arn(),
148-
sagemaker_session=get_sm_session(),
149142
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
143+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
150144
)
151145

152146
estimator.fit(
@@ -161,14 +155,11 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
161155
training_job_name=estimator.latest_training_job.name,
162156
model_id=model_id,
163157
model_version=model_version,
164-
sagemaker_session=get_sm_session(),
165158
)
166159

167160
# uses ml.p3.2xlarge instance
168161
predictor = estimator.deploy(
169162
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
170-
role=get_sm_session().get_caller_identity_arn(),
171-
sagemaker_session=get_sm_session(),
172163
)
173164

174165
response = predictor.predict(["hello", "world"])
@@ -182,9 +173,8 @@ def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references)
182173

183174
estimator = JumpStartEstimator(
184175
model_id=model_id,
185-
role=get_sm_session().get_caller_identity_arn(),
186-
sagemaker_session=get_sm_session(),
187176
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
177+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
188178
)
189179
with pytest.raises(Exception):
190180
estimator.fit(

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15553,6 +15553,8 @@
1555315553
},
1555415554
"inference_enable_network_isolation": True,
1555515555
"training_enable_network_isolation": True,
15556+
"default_training_dataset_uri": None,
15557+
"default_training_dataset_key": "training-datasets/tf_flowers/",
1555615558
"resource_name_base": "pt-ic-mobilenet-v2",
1555715559
"hosting_eula_key": None,
1555815560
"hosting_model_package_arns": {},

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def test_jumpstart_model_specs():
378378
specs1.training_script_key
379379
== "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz"
380380
)
381+
assert specs1.default_training_dataset_key == "training-datasets/tf_flowers/"
381382
assert specs1.hyperparameters == [
382383
JumpStartHyperparameter(
383384
{

0 commit comments

Comments
 (0)