Skip to content

Commit 9a410e5

Browse files
authored
feat: Added utils for extracting JS data sources (#1471)
* added utils for accessing hosting data sources * added utils for accessing hosting data sources * removed other changes * fixed formatting issues * remove .keys() * updated JumpStartModelDataSource * fix slots * remove print * fix tests * update tests
1 parent 0151209 commit 9a410e5

File tree

3 files changed

+66
-26
lines changed

3 files changed

+66
-26
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -884,41 +884,19 @@ def to_json(self) -> Dict[str, Any]:
884884
return json_obj
885885

886886

887-
class JumpStartModelDataSource(JumpStartDataHolderType):
887+
class JumpStartModelDataSource(AdditionalModelDataSource):
888888
"""Data class JumpStart additional model data source."""
889889

890-
__slots__ = ["version", "additional_model_data_source"]
891-
892-
def __init__(self, spec: Dict[str, Any]):
893-
"""Initializes a JumpStartModelDataSource object.
894-
895-
Args:
896-
spec (Dict[str, Any]): Dictionary representation of data source.
897-
"""
898-
self.from_json(spec)
890+
__slots__ = ["artifact_version"] + AdditionalModelDataSource.__slots__
899891

900892
def from_json(self, json_obj: Dict[str, Any]) -> None:
901893
"""Sets fields in object based on json.
902894
903895
Args:
904896
json_obj (Dict[str, Any]): Dictionary representation of data source.
905897
"""
906-
self.version: str = json_obj["artifact_version"]
907-
self.additional_model_data_source: AdditionalModelDataSource = AdditionalModelDataSource(
908-
json_obj
909-
)
910-
911-
def to_json(self) -> Dict[str, Any]:
912-
"""Returns json representation of JumpStartModelDataSource object."""
913-
json_obj = {}
914-
for att in self.__slots__:
915-
if hasattr(self, att):
916-
cur_val = getattr(self, att)
917-
if issubclass(type(cur_val), JumpStartDataHolderType):
918-
json_obj[att] = cur_val.to_json()
919-
else:
920-
json_obj[att] = cur_val
921-
return json_obj
898+
super().from_json(json_obj)
899+
self.artifact_version: str = json_obj["artifact_version"]
922900

923901

924902
class JumpStartAdditionalDataSources(JumpStartDataHolderType):
@@ -1655,6 +1633,19 @@ def supports_incremental_training(self) -> bool:
16551633
"""Returns True if the model supports incremental training."""
16561634
return self.incremental_training_supported
16571635

1636+
def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]:
1637+
"""Returns data sources for speculative decoding."""
1638+
return self.hosting_additional_data_sources.speculative_decoding or []
1639+
1640+
def get_additional_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]:
1641+
"""Returns a list of the additional S3 data sources for use by the model."""
1642+
additional_data_sources = []
1643+
if self.hosting_additional_data_sources:
1644+
for data_source in self.hosting_additional_data_sources.to_json():
1645+
data_sources = getattr(self.hosting_additional_data_sources, data_source) or []
1646+
additional_data_sources.extend(data_sources)
1647+
return additional_data_sources
1648+
16581649

16591650
class JumpStartVersionedModelId(JumpStartDataHolderType):
16601651
"""Data class for versioned model IDs."""

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7518,6 +7518,37 @@
75187518
"hosting_additional_data_sources": None,
75197519
}
75207520

7521+
BASE_HOSTING_ADDITIONAL_DATA_SOURCES = {
7522+
"hosting_additional_data_sources": {
7523+
"speculative_decoding": [
7524+
{
7525+
"channel_name": "speculative_decoding_channel",
7526+
"artifact_version": "version",
7527+
"s3_data_source": {
7528+
"compression_type": "None",
7529+
"s3_data_type": "S3Prefix",
7530+
"s3_uri": "s3://bucket/path1",
7531+
"hub_access_config": None,
7532+
"model_access_config": None,
7533+
},
7534+
}
7535+
],
7536+
"scripts": [
7537+
{
7538+
"channel_name": "scripts_channel",
7539+
"artifact_version": "version",
7540+
"s3_data_source": {
7541+
"compression_type": "None",
7542+
"s3_data_type": "S3Prefix",
7543+
"s3_uri": "s3://bucket/path1",
7544+
"hub_access_config": None,
7545+
"model_access_config": None,
7546+
},
7547+
}
7548+
],
7549+
},
7550+
}
7551+
75217552
BASE_HEADER = {
75227553
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
75237554
"version": "1.0.0",

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from tests.unit.sagemaker.jumpstart.constants import (
3030
BASE_SPEC,
31+
BASE_HOSTING_ADDITIONAL_DATA_SOURCES,
3132
INFERENCE_CONFIG_RANKINGS,
3233
INFERENCE_CONFIGS,
3334
TRAINING_CONFIG_RANKINGS,
@@ -436,6 +437,23 @@ def test_jumpstart_model_specs():
436437
assert specs3 == specs1
437438

438439

440+
def test_get_speculative_decoding_s3_data_sources():
441+
specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES})
442+
assert (
443+
specs.get_speculative_decoding_s3_data_sources()
444+
== specs.hosting_additional_data_sources.speculative_decoding
445+
)
446+
447+
448+
def test_get_additional_s3_data_sources():
449+
specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES})
450+
data_sources = [
451+
*specs.hosting_additional_data_sources.speculative_decoding,
452+
*specs.hosting_additional_data_sources.scripts,
453+
]
454+
assert specs.get_additional_s3_data_sources() == data_sources
455+
456+
439457
def test_jumpstart_image_uri_instance_variants():
440458

441459
assert (

0 commit comments

Comments
 (0)