Skip to content

Commit 997e2ce

Browse files
grenmester and Jacky Lee authored
feat: add build/deploy support for fine-tuned JS models (#1473)
* feat: add support for fine-tuned JS models * Refactor * Refactor * Refactor * Refactor * pylint * pylint --------- Co-authored-by: Jacky Lee <[email protected]>
1 parent c6581ff commit 997e2ce

File tree

3 files changed

+68
-8
lines changed

3 files changed

+68
-8
lines changed

src/sagemaker/enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,5 @@ class Tag(str, Enum):
4646
"""Enum class for tag keys to apply to models."""
4747

4848
OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
49+
FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
50+
FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
from __future__ import absolute_import
1515

1616
import copy
17+
import re
1718
from abc import ABC, abstractmethod
1819
from datetime import datetime, timedelta
1920
from typing import Type, Any, List, Dict, Optional
2021
import logging
2122

23+
from botocore.exceptions import ClientError
24+
25+
from sagemaker.enums import Tag
2226
from sagemaker.jumpstart import enums
2327
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message
2428
from sagemaker.model import Model
@@ -105,6 +109,7 @@ def __init__(self):
105109
self.nb_instance_type = None
106110
self.ram_usage_model_load = None
107111
self.jumpstart = None
112+
self.model_metadata = None
108113

109114
@abstractmethod
110115
def _prepare_for_mode(self):
@@ -520,6 +525,54 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
520525

521526
return self.pysdk_model.list_deployment_configs()
522527

528+
def _is_fine_tuned_model(self) -> bool:
529+
"""Checks whether a fine-tuned model exists."""
530+
return self.model_metadata and (
531+
self.model_metadata.get("FINE_TUNING_MODEL_PATH")
532+
or self.model_metadata.get("FINE_TUNING_JOB_NAME")
533+
)
534+
535+
def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> Type[Model]:
    """Set the model path and data and add fine-tuning tags for the model.

    Args:
        pysdk_model (Type[Model]): Model whose ``model_data`` S3 source is
            repointed at the fine-tuned artifacts; mutated in place.

    Returns:
        Type[Model]: The same model object, updated and tagged.

    Raises:
        ValueError: If ``FINE_TUNING_MODEL_PATH`` is not an s3:// or https://
            URI, if the ``FINE_TUNING_JOB_NAME`` training job cannot be
            described, or if neither metadata key is present.
    """
    # TODO: determine precedence of FINE_TUNING_MODEL_PATH and FINE_TUNING_JOB_NAME
    if fine_tuning_model_path := self.model_metadata.get("FINE_TUNING_MODEL_PATH"):
        # Accept only s3:// or https:// URIs that have a bucket/host component.
        if not re.match(r"^(https|s3)://([^/]+)/?(.*)$", fine_tuning_model_path):
            raise ValueError(
                f"Invalid path for FINE_TUNING_MODEL_PATH: {fine_tuning_model_path}."
            )
        pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
        pysdk_model.add_tags(
            {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path}
        )
        return pysdk_model

    if fine_tuning_job_name := self.model_metadata.get("FINE_TUNING_JOB_NAME"):
        # Keep the try body to the one call that can raise ClientError.
        try:
            response = self.sagemaker_session.sagemaker_client.describe_training_job(
                TrainingJobName=fine_tuning_job_name
            )
        except ClientError as e:
            # Chain the AWS error so callers can see the root cause.
            raise ValueError(
                f"Invalid job name for FINE_TUNING_JOB_NAME: {fine_tuning_job_name}."
            ) from e
        fine_tuning_model_path = response["OutputDataConfig"]["S3OutputPath"]
        pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
        pysdk_model.model_data["S3DataSource"]["CompressionType"] = response[
            "OutputDataConfig"
        ]["CompressionType"]
        pysdk_model.add_tags(
            [
                {"key": Tag.FINE_TUNING_JOB_NAME, "value": fine_tuning_job_name},
                {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path},
            ]
        )
        return pysdk_model

    raise ValueError(
        "Input model not found. Please provide either `model_path`, or "
        "`FINE_TUNING_MODEL_PATH` or `FINE_TUNING_JOB_NAME` under `model_metadata`."
    )
575+
523576
def _build_for_jumpstart(self):
524577
"""Placeholder docstring"""
525578
if hasattr(self, "pysdk_model") and self.pysdk_model is not None:
@@ -534,6 +587,9 @@ def _build_for_jumpstart(self):
534587

535588
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
536589

590+
if self._is_fine_tuned_model():
591+
pysdk_model = self._update_model_data_for_fine_tuned_model(pysdk_model)
592+
537593
if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
538594
raise ValueError(
539595
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
@@ -714,7 +770,7 @@ def _optimize_for_jumpstart(
714770
**create_optimization_job_args
715771
)
716772

717-
def _is_gated_model(self, model) -> bool:
773+
def _is_gated_model(self, model: Model) -> bool:
718774
"""Determine if ``this`` Model is Gated
719775
720776
Args:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Holds the ModelBuilder class and the ModelServer enum."""
1414
from __future__ import absolute_import
15+
1516
import uuid
1617
from typing import Any, Type, List, Dict, Optional, Union
1718
from dataclasses import dataclass, field
@@ -278,8 +279,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
278279
default=None,
279280
metadata={
280281
"help": "Define the model metadata to override, currently supports `HF_TASK`, "
281-
"`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
282-
"the Hub, Adding unsupported task types will throw an exception"
282+
"`MLFLOW_MODEL_PATH`, `FINE_TUNING_MODEL_PATH`, and `FINE_TUNING_JOB_NAME`. HF_TASK "
283+
"should be set for new models without task metadata in the Hub, Adding unsupported "
284+
"task types will throw an exception."
283285
},
284286
)
285287

@@ -739,8 +741,8 @@ def build( # pylint: disable=R0911
739741
)
740742

741743
self.serve_settings = self._get_serve_setting()
742-
743744
self._is_custom_image_uri = self.image_uri is not None
745+
744746
self._is_mlflow_model = self._check_if_input_is_mlflow_model()
745747
if self._is_mlflow_model:
746748
logger.warning(
@@ -925,7 +927,7 @@ def _try_fetch_gpu_info(self):
925927
f"Unable to determine single GPU size for instance: [{self.instance_type}]"
926928
)
927929

928-
def optimize(self, *args, **kwargs) -> Type[Model]:
930+
def optimize(self, *args, **kwargs) -> Model:
929931
"""Runs a model optimization job.
930932
931933
Args:
@@ -948,7 +950,7 @@ def optimize(self, *args, **kwargs) -> Type[Model]:
948950
function creates one using the default AWS configuration chain.
949951
950952
Returns:
951-
Type[Model]: A deployable ``Model`` object.
953+
Model: A deployable ``Model`` object.
952954
"""
953955
# need to get telemetry_opt_out info before telemetry decorator is called
954956
self.serve_settings = self._get_serve_setting()
@@ -972,7 +974,7 @@ def _model_builder_optimize_wrapper(
972974
kms_key: Optional[str] = None,
973975
max_runtime_in_sec: Optional[int] = None,
974976
sagemaker_session: Optional[Session] = None,
975-
) -> Type[Model]:
977+
) -> Model:
976978
"""Runs a model optimization job.
977979
978980
Args:
@@ -1002,7 +1004,7 @@ def _model_builder_optimize_wrapper(
10021004
function creates one using the default AWS configuration chain.
10031005
10041006
Returns:
1005-
Type[Model]: A deployable ``Model`` object.
1007+
Model: A deployable ``Model`` object.
10061008
"""
10071009
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
10081010
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)

0 commit comments

Comments (0)