Skip to content

Commit 6462f05

Browse files
author
isha chidrawar
committed
Merge branch 'master' of github.com:IshaChid76/sagemaker-python-sdk
2 parents 0921a1c + 8dfb484 commit 6462f05

File tree

20 files changed

+505
-197
lines changed

20 files changed

+505
-197
lines changed

src/sagemaker/_studio.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def _find_config(working_dir=None):
6565
wd = Path(working_dir) if working_dir else Path.cwd()
6666

6767
path = None
68-
while path is None and not wd.match("/"):
68+
69+
# Get the root of the current working directory for both Windows and Unix-like systems
70+
root = Path(wd.anchor)
71+
while path is None and wd != root:
6972
candidate = wd / STUDIO_PROJECT_CONFIG
7073
if Path.exists(candidate):
7174
path = candidate

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@
8585
"2.2": "2.2.0",
8686
"2.3": "2.3.0",
8787
"2.4": "2.4.0",
88-
"2.5": "2.5.1"
88+
"2.5": "2.5.1",
89+
"2.6": "2.6.0"
8990
},
9091
"versions": {
9192
"0.4.0": {
@@ -1253,6 +1254,50 @@
12531254
"us-west-2": "763104351884"
12541255
},
12551256
"repository": "pytorch-inference"
1257+
},
1258+
"2.6.0": {
1259+
"py_versions": [
1260+
"py312"
1261+
],
1262+
"registries": {
1263+
"af-south-1": "626614931356",
1264+
"ap-east-1": "871362719292",
1265+
"ap-northeast-1": "763104351884",
1266+
"ap-northeast-2": "763104351884",
1267+
"ap-northeast-3": "364406365360",
1268+
"ap-south-1": "763104351884",
1269+
"ap-south-2": "772153158452",
1270+
"ap-southeast-1": "763104351884",
1271+
"ap-southeast-2": "763104351884",
1272+
"ap-southeast-3": "907027046896",
1273+
"ap-southeast-4": "457447274322",
1274+
"ap-southeast-5": "550225433462",
1275+
"ap-southeast-7": "590183813437",
1276+
"ca-central-1": "763104351884",
1277+
"ca-west-1": "204538143572",
1278+
"cn-north-1": "727897471807",
1279+
"cn-northwest-1": "727897471807",
1280+
"eu-central-1": "763104351884",
1281+
"eu-central-2": "380420809688",
1282+
"eu-north-1": "763104351884",
1283+
"eu-south-1": "692866216735",
1284+
"eu-south-2": "503227376785",
1285+
"eu-west-1": "763104351884",
1286+
"eu-west-2": "763104351884",
1287+
"eu-west-3": "763104351884",
1288+
"il-central-1": "780543022126",
1289+
"me-central-1": "914824155844",
1290+
"me-south-1": "217643126080",
1291+
"mx-central-1": "637423239942",
1292+
"sa-east-1": "763104351884",
1293+
"us-east-1": "763104351884",
1294+
"us-east-2": "763104351884",
1295+
"us-gov-east-1": "446045086412",
1296+
"us-gov-west-1": "442386744353",
1297+
"us-west-1": "763104351884",
1298+
"us-west-2": "763104351884"
1299+
},
1300+
"repository": "pytorch-inference"
12561301
}
12571302
}
12581303
},
@@ -1628,7 +1673,8 @@
16281673
"2.2": "2.2.0",
16291674
"2.3": "2.3.0",
16301675
"2.4": "2.4.0",
1631-
"2.5": "2.5.1"
1676+
"2.5": "2.5.1",
1677+
"2.6": "2.6.0"
16321678
},
16331679
"versions": {
16341680
"0.4.0": {
@@ -2801,6 +2847,50 @@
28012847
"us-west-2": "763104351884"
28022848
},
28032849
"repository": "pytorch-training"
2850+
},
2851+
"2.6.0": {
2852+
"py_versions": [
2853+
"py312"
2854+
],
2855+
"registries": {
2856+
"af-south-1": "626614931356",
2857+
"ap-east-1": "871362719292",
2858+
"ap-northeast-1": "763104351884",
2859+
"ap-northeast-2": "763104351884",
2860+
"ap-northeast-3": "364406365360",
2861+
"ap-south-1": "763104351884",
2862+
"ap-south-2": "772153158452",
2863+
"ap-southeast-1": "763104351884",
2864+
"ap-southeast-2": "763104351884",
2865+
"ap-southeast-3": "907027046896",
2866+
"ap-southeast-4": "457447274322",
2867+
"ap-southeast-5": "550225433462",
2868+
"ap-southeast-7": "590183813437",
2869+
"ca-central-1": "763104351884",
2870+
"ca-west-1": "204538143572",
2871+
"cn-north-1": "727897471807",
2872+
"cn-northwest-1": "727897471807",
2873+
"eu-central-1": "763104351884",
2874+
"eu-central-2": "380420809688",
2875+
"eu-north-1": "763104351884",
2876+
"eu-south-1": "692866216735",
2877+
"eu-south-2": "503227376785",
2878+
"eu-west-1": "763104351884",
2879+
"eu-west-2": "763104351884",
2880+
"eu-west-3": "763104351884",
2881+
"il-central-1": "780543022126",
2882+
"me-central-1": "914824155844",
2883+
"me-south-1": "217643126080",
2884+
"mx-central-1": "637423239942",
2885+
"sa-east-1": "763104351884",
2886+
"us-east-1": "763104351884",
2887+
"us-east-2": "763104351884",
2888+
"us-gov-east-1": "446045086412",
2889+
"us-gov-west-1": "442386744353",
2890+
"us-west-1": "763104351884",
2891+
"us-west-2": "763104351884"
2892+
},
2893+
"repository": "pytorch-training"
28042894
}
28052895
}
28062896
}

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
JUMPSTART_LOGGER,
5757
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5858
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
59+
JUMPSTART_MODEL_HUB_NAME,
5960
)
6061
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
6162
from sagemaker.jumpstart.factory import model
@@ -313,16 +314,31 @@ def _add_hub_access_config_to_kwargs_inputs(
313314
):
314315
"""Adds HubAccessConfig to kwargs inputs"""
315316

317+
dataset_uri = kwargs.specs.default_training_dataset_uri
316318
if isinstance(kwargs.inputs, str):
317-
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
319+
if dataset_uri is not None and dataset_uri == kwargs.inputs:
320+
kwargs.inputs = TrainingInput(
321+
s3_data=kwargs.inputs, hub_access_config=hub_access_config
322+
)
318323
elif isinstance(kwargs.inputs, TrainingInput):
319-
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
324+
if (
325+
dataset_uri is not None
326+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
327+
):
328+
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
320329
elif isinstance(kwargs.inputs, dict):
321330
for k, v in kwargs.inputs.items():
322331
if isinstance(v, str):
323-
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
332+
training_input = TrainingInput(s3_data=v)
333+
if dataset_uri is not None and dataset_uri == v:
334+
training_input.add_hub_access_config(hub_access_config=hub_access_config)
335+
kwargs.inputs[k] = training_input
324336
elif isinstance(kwargs.inputs, TrainingInput):
325-
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
337+
if (
338+
dataset_uri is not None
339+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
340+
):
341+
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
326342

327343
return kwargs
328344

@@ -616,8 +632,13 @@ def _add_model_reference_arn_to_kwargs(
616632

617633
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
618634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
619-
620-
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
635+
# hub_arn is by default None unless the user specifies the hub_name
636+
# If no hub_name is specified, the public hub is assumed
637+
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
638+
if (
639+
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
640+
or is_private_hub
641+
):
621642
default_model_uri = model_uris.retrieve(
622643
model_scope=JumpStartScriptScope.TRAINING,
623644
instance_type=kwargs.instance_type,

src/sagemaker/jumpstart/hub/hub.py

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,25 @@
1616
from datetime import datetime
1717
import logging
1818
from typing import Optional, Dict, List, Any, Union
19-
from botocore import exceptions
2019

2120
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
2221
from sagemaker.jumpstart.enums import JumpStartScriptScope
2322
from sagemaker.session import Session
2423

25-
from sagemaker.jumpstart.constants import (
26-
JUMPSTART_LOGGER,
27-
)
2824
from sagemaker.jumpstart.types import (
2925
HubContentType,
3026
)
3127
from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues
3228
from sagemaker.jumpstart.hub.utils import (
3329
get_hub_model_version,
3430
get_info_from_hub_resource_arn,
35-
create_hub_bucket_if_it_does_not_exist,
36-
generate_default_hub_bucket_name,
37-
create_s3_object_reference_from_uri,
3831
construct_hub_arn_from_name,
3932
)
4033

4134
from sagemaker.jumpstart.notebook_utils import (
4235
list_jumpstart_models,
4336
)
4437

45-
from sagemaker.jumpstart.hub.types import (
46-
S3ObjectLocation,
47-
)
4838
from sagemaker.jumpstart.hub.interfaces import (
4939
DescribeHubResponse,
5040
DescribeHubContentResponse,
@@ -66,8 +56,8 @@ class Hub:
6656
def __init__(
6757
self,
6858
hub_name: str,
59+
sagemaker_session: Session,
6960
bucket_name: Optional[str] = None,
70-
sagemaker_session: Optional[Session] = None,
7161
) -> None:
7262
"""Instantiates a SageMaker ``Hub``.
7363
@@ -78,41 +68,11 @@ def __init__(
7868
"""
7969
self.hub_name = hub_name
8070
self.region = sagemaker_session.boto_region_name
71+
self.bucket_name = bucket_name
8172
self._sagemaker_session = (
8273
sagemaker_session
8374
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
8475
)
85-
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
86-
87-
def _fetch_hub_bucket_name(self) -> str:
88-
"""Retrieves hub bucket name from Hub config if exists"""
89-
try:
90-
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
91-
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
92-
if hub_output_location:
93-
location = create_s3_object_reference_from_uri(hub_output_location)
94-
return location.bucket
95-
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
96-
JUMPSTART_LOGGER.warning(
97-
"There is not a Hub bucket associated with %s. Using %s",
98-
self.hub_name,
99-
default_bucket_name,
100-
)
101-
return default_bucket_name
102-
except exceptions.ClientError:
103-
hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
104-
JUMPSTART_LOGGER.warning(
105-
"There is not a Hub bucket associated with %s. Using %s",
106-
self.hub_name,
107-
hub_bucket_name,
108-
)
109-
return hub_bucket_name
110-
111-
def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None:
112-
"""Generates an ``S3ObjectLocation`` given a Hub name."""
113-
hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
114-
curr_timestamp = datetime.now().timestamp()
115-
return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
11676

11777
def _get_latest_model_version(self, model_id: str) -> str:
11878
"""Populates the latest version of a model from specs no matter what is passed.
@@ -132,19 +92,22 @@ def create(
13292
tags: Optional[str] = None,
13393
) -> Dict[str, str]:
13494
"""Creates a hub with the given description"""
95+
curr_timestamp = datetime.now().timestamp()
13596

136-
create_hub_bucket_if_it_does_not_exist(
137-
self.hub_storage_location.bucket, self._sagemaker_session
138-
)
97+
request = {
98+
"hub_name": self.hub_name,
99+
"hub_description": description,
100+
"hub_display_name": display_name,
101+
"hub_search_keywords": search_keywords,
102+
"tags": tags,
103+
}
139104

140-
return self._sagemaker_session.create_hub(
141-
hub_name=self.hub_name,
142-
hub_description=description,
143-
hub_display_name=display_name,
144-
hub_search_keywords=search_keywords,
145-
s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()},
146-
tags=tags,
147-
)
105+
if self.bucket_name:
106+
request["s3_storage_config"] = {
107+
"S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}")
108+
}
109+
110+
return self._sagemaker_session.create_hub(**request)
148111

149112
def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse:
150113
"""Returns descriptive information about the Hub"""

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
630630
if json_obj.get("ValidationSupported")
631631
else None
632632
)
633-
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
634633
self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
635634
self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
636635
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -671,6 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
671670
)
672671

673672
if self.training_supported:
673+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
674+
"DefaultTrainingDatasetUri"
675+
)
674676
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
675677
"TrainingModelPackageArtifactUri"
676678
)

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,10 @@ def make_model_specs_from_describe_hub_content_response(
279279
specs["training_instance_type_variants"] = (
280280
hub_model_document.training_instance_type_variants
281281
)
282+
if hub_model_document.default_training_dataset_uri:
283+
_, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable
284+
hub_model_document.default_training_dataset_uri
285+
)
286+
specs["default_training_dataset_key"] = default_training_dataset_key
287+
specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri
282288
return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)

0 commit comments

Comments
 (0)