Skip to content

Commit 0a56efd

Browse files
authored
Merge branch 'master' into update-pt-2.5.1
2 parents 6f60f5c + 8dfb484 commit 0a56efd

File tree

6 files changed

+57
-179
lines changed

6 files changed

+57
-179
lines changed

src/sagemaker/jumpstart/hub/hub.py

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,25 @@
1616
from datetime import datetime
1717
import logging
1818
from typing import Optional, Dict, List, Any, Union
19-
from botocore import exceptions
2019

2120
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
2221
from sagemaker.jumpstart.enums import JumpStartScriptScope
2322
from sagemaker.session import Session
2423

25-
from sagemaker.jumpstart.constants import (
26-
JUMPSTART_LOGGER,
27-
)
2824
from sagemaker.jumpstart.types import (
2925
HubContentType,
3026
)
3127
from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues
3228
from sagemaker.jumpstart.hub.utils import (
3329
get_hub_model_version,
3430
get_info_from_hub_resource_arn,
35-
create_hub_bucket_if_it_does_not_exist,
36-
generate_default_hub_bucket_name,
37-
create_s3_object_reference_from_uri,
3831
construct_hub_arn_from_name,
3932
)
4033

4134
from sagemaker.jumpstart.notebook_utils import (
4235
list_jumpstart_models,
4336
)
4437

45-
from sagemaker.jumpstart.hub.types import (
46-
S3ObjectLocation,
47-
)
4838
from sagemaker.jumpstart.hub.interfaces import (
4939
DescribeHubResponse,
5040
DescribeHubContentResponse,
@@ -66,8 +56,8 @@ class Hub:
6656
def __init__(
6757
self,
6858
hub_name: str,
59+
sagemaker_session: Session,
6960
bucket_name: Optional[str] = None,
70-
sagemaker_session: Optional[Session] = None,
7161
) -> None:
7262
"""Instantiates a SageMaker ``Hub``.
7363
@@ -78,41 +68,11 @@ def __init__(
7868
"""
7969
self.hub_name = hub_name
8070
self.region = sagemaker_session.boto_region_name
71+
self.bucket_name = bucket_name
8172
self._sagemaker_session = (
8273
sagemaker_session
8374
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
8475
)
85-
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
86-
87-
def _fetch_hub_bucket_name(self) -> str:
88-
"""Retrieves hub bucket name from Hub config if exists"""
89-
try:
90-
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
91-
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
92-
if hub_output_location:
93-
location = create_s3_object_reference_from_uri(hub_output_location)
94-
return location.bucket
95-
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
96-
JUMPSTART_LOGGER.warning(
97-
"There is not a Hub bucket associated with %s. Using %s",
98-
self.hub_name,
99-
default_bucket_name,
100-
)
101-
return default_bucket_name
102-
except exceptions.ClientError:
103-
hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
104-
JUMPSTART_LOGGER.warning(
105-
"There is not a Hub bucket associated with %s. Using %s",
106-
self.hub_name,
107-
hub_bucket_name,
108-
)
109-
return hub_bucket_name
110-
111-
def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None:
112-
"""Generates an ``S3ObjectLocation`` given a Hub name."""
113-
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
114-
curr_timestamp = datetime.now().timestamp()
115-
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
11676

11777
def _get_latest_model_version(self, model_id: str) -> str:
11878
"""Populates the latest version of a model from specs no matter what is passed.
@@ -132,19 +92,22 @@ def create(
13292
tags: Optional[str] = None,
13393
) -> Dict[str, str]:
13494
"""Creates a hub with the given description"""
95+
curr_timestamp = datetime.now().timestamp()
13596

136-
create_hub_bucket_if_it_does_not_exist(
137-
self.hub_storage_location.bucket, self._sagemaker_session
138-
)
97+
request = {
98+
"hub_name": self.hub_name,
99+
"hub_description": description,
100+
"hub_display_name": display_name,
101+
"hub_search_keywords": search_keywords,
102+
"tags": tags,
103+
}
139104

140-
return self._sagemaker_session.create_hub(
141-
hub_name=self.hub_name,
142-
hub_description=description,
143-
hub_display_name=display_name,
144-
hub_search_keywords=search_keywords,
145-
s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()},
146-
tags=tags,
147-
)
105+
if self.bucket_name:
106+
request["s3_storage_config"] = {
107+
"S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}")
108+
}
109+
110+
return self._sagemaker_session.create_hub(**request)
148111

149112
def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse:
150113
"""Returns descriptive information about the Hub"""

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from __future__ import absolute_import
1616
import re
1717
from typing import Optional, List, Any
18-
from sagemaker.jumpstart.hub.types import S3ObjectLocation
19-
from sagemaker.s3_utils import parse_s3_url
2018
from sagemaker.session import Session
2119
from sagemaker.utils import aws_partition
2220
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
@@ -139,61 +137,6 @@ def generate_hub_arn_for_init_kwargs(
139137
return hub_arn
140138

141139

142-
def generate_default_hub_bucket_name(
143-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
144-
) -> str:
145-
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
146-
147-
Returns:
148-
str: The name of the default bucket. If the name was not explicitly specified through
149-
the Session or sagemaker_config, the bucket will take the form:
150-
``sagemaker-hubs-{region}-{AWS account ID}``.
151-
"""
152-
153-
region: str = sagemaker_session.boto_region_name
154-
account_id: str = sagemaker_session.account_id()
155-
156-
# TODO: Validate and fast fail
157-
158-
return f"sagemaker-hubs-{region}-{account_id}"
159-
160-
161-
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
162-
"""Utility to help generate an S3 object reference"""
163-
if not s3_uri:
164-
return None
165-
166-
bucket, key = parse_s3_url(s3_uri)
167-
168-
return S3ObjectLocation(
169-
bucket=bucket,
170-
key=key,
171-
)
172-
173-
174-
def create_hub_bucket_if_it_does_not_exist(
175-
bucket_name: Optional[str] = None,
176-
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
177-
) -> str:
178-
"""Creates the default SageMaker Hub bucket if it does not exist.
179-
180-
Returns:
181-
str: The name of the default bucket. Takes the form:
182-
``sagemaker-hubs-{region}-{AWS account ID}``.
183-
"""
184-
185-
region: str = sagemaker_session.boto_region_name
186-
if bucket_name is None:
187-
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)
188-
189-
sagemaker_session._create_s3_bucket_if_it_does_not_exist(
190-
bucket_name=bucket_name,
191-
region=region,
192-
)
193-
194-
return bucket_name
195-
196-
197140
def is_gated_bucket(bucket_name: str) -> bool:
198141
"""Returns true if the bucket name is the JumpStart gated bucket."""
199142
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET

tests/integ/sagemaker/experiments/helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
from __future__ import absolute_import
1414

1515
from contextlib import contextmanager
16+
import pytest
17+
import logging
1618

1719
from sagemaker import utils
1820
from sagemaker.experiments.experiment import Experiment
21+
from sagemaker.experiments._run_context import _RunContext
1922

2023
EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ"
2124

@@ -40,3 +43,16 @@ def cleanup_exp_resources(exp_names, sagemaker_session):
4043
for exp_name in exp_names:
4144
exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
4245
exp._delete_all(action="--force")
46+
47+
@pytest.fixture
48+
def clear_run_context():
49+
current_run = _RunContext.get_current_run()
50+
if current_run == None:
51+
return
52+
53+
logging.info(
54+
f"RunContext already populated by run {current_run.run_name}"
55+
f" in experiment {current_run.experiment_name}."
56+
" Clearing context manually"
57+
)
58+
_RunContext.drop_current_run()

tests/integ/sagemaker/experiments/test_run.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from sagemaker.experiments.trial_component import _TrialComponent
3333
from sagemaker.sklearn import SKLearn
3434
from sagemaker.utils import retry_with_backoff, unique_name_from_base
35-
from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources
35+
from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources, clear_run_context
3636
from sagemaker.experiments.run import (
3737
RUN_NAME_BASE,
3838
DELIMITER,
@@ -55,7 +55,7 @@ def artifact_file_path(tempdir):
5555
metric_name = "Test-Local-Init-Log-Metric"
5656

5757

58-
def test_local_run_with_load(sagemaker_session, artifact_file_path):
58+
def test_local_run_with_load(sagemaker_session, artifact_file_path, clear_run_context):
5959
exp_name = f"My-Local-Exp-{name()}"
6060
with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
6161
# Run name is not provided, will create a new TC
@@ -86,7 +86,9 @@ def verify_load_run():
8686
retry_with_backoff(verify_load_run, 4)
8787

8888

89-
def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session):
89+
def test_two_local_run_init_with_same_run_name_and_different_exp_names(
90+
sagemaker_session, clear_run_context
91+
):
9092
exp_name1 = f"my-two-local-exp1-{name()}"
9193
exp_name2 = f"my-two-local-exp2-{name()}"
9294
run_name = "test-run"
@@ -124,7 +126,9 @@ def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker
124126
("my-test4", "test-run", "run-display-name-test"), # with supplied display name
125127
],
126128
)
127-
def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_names):
129+
def test_run_name_vs_trial_component_name_edge_cases(
130+
sagemaker_session, input_names, clear_run_context
131+
):
128132
exp_name, run_name, run_display_name = input_names
129133
with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
130134
with Run(
@@ -177,6 +181,7 @@ def test_run_from_local_and_train_job_and_all_exp_cfg_match(
177181
execution_role,
178182
sagemaker_client_config,
179183
sagemaker_metrics_config,
184+
clear_run_context,
180185
):
181186
# Notes:
182187
# 1. The 1st Run created locally and its exp config was auto passed to the job
@@ -277,6 +282,7 @@ def test_run_from_local_and_train_job_and_exp_cfg_not_match(
277282
execution_role,
278283
sagemaker_client_config,
279284
sagemaker_metrics_config,
285+
clear_run_context,
280286
):
281287
# Notes:
282288
# 1. The 1st Run created locally and its exp config was auto passed to the job
@@ -363,6 +369,7 @@ def test_run_from_train_job_only(
363369
execution_role,
364370
sagemaker_client_config,
365371
sagemaker_metrics_config,
372+
clear_run_context,
366373
):
367374
# Notes:
368375
# 1. No Run created locally or specified in experiment config
@@ -413,6 +420,7 @@ def test_run_from_processing_job_and_override_default_exp_config(
413420
execution_role,
414421
sagemaker_client_config,
415422
sagemaker_metrics_config,
423+
clear_run_context,
416424
):
417425
# Notes:
418426
# 1. The 1st Run (run) created locally
@@ -492,6 +500,7 @@ def test_run_from_transform_job(
492500
execution_role,
493501
sagemaker_client_config,
494502
sagemaker_metrics_config,
503+
clear_run_context,
495504
):
496505
# Notes:
497506
# 1. The 1st Run (run) created locally
@@ -573,6 +582,7 @@ def test_load_run_auto_pass_in_exp_config_to_job(
573582
execution_role,
574583
sagemaker_client_config,
575584
sagemaker_metrics_config,
585+
clear_run_context,
576586
):
577587
# Notes:
578588
# 1. In local side, load the Run created previously and invoke a job under the load context
@@ -621,7 +631,7 @@ def test_load_run_auto_pass_in_exp_config_to_job(
621631
)
622632

623633

624-
def test_list(run_obj, sagemaker_session):
634+
def test_list(run_obj, sagemaker_session, clear_run_context):
625635
tc1 = _TrialComponent.create(
626636
trial_component_name=f"non-run-tc1-{name()}",
627637
sagemaker_session=sagemaker_session,
@@ -643,7 +653,7 @@ def test_list(run_obj, sagemaker_session):
643653
assert run_tcs[0].experiment_config == run_obj.experiment_config
644654

645655

646-
def test_list_twice(run_obj, sagemaker_session):
656+
def test_list_twice(run_obj, sagemaker_session, clear_run_context):
647657
tc1 = _TrialComponent.create(
648658
trial_component_name=f"non-run-tc1-{name()}",
649659
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)