14
14
from __future__ import absolute_import
15
15
16
16
import copy
17
+ import re
17
18
from abc import ABC , abstractmethod
18
19
from datetime import datetime , timedelta
19
20
from typing import Type , Any , List , Dict , Optional
20
21
import logging
21
22
23
+ from botocore .exceptions import ClientError
24
+
25
+ from sagemaker .enums import Tag
22
26
from sagemaker .jumpstart import enums
23
27
from sagemaker .jumpstart .utils import verify_model_region_and_return_specs , get_eula_message
24
28
from sagemaker .model import Model
@@ -105,6 +109,7 @@ def __init__(self):
105
109
self .nb_instance_type = None
106
110
self .ram_usage_model_load = None
107
111
self .jumpstart = None
112
+ self .model_metadata = None
108
113
109
114
@abstractmethod
110
115
def _prepare_for_mode (self ):
@@ -520,6 +525,54 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
520
525
521
526
return self .pysdk_model .list_deployment_configs ()
522
527
528
+ def _is_fine_tuned_model (self ) -> bool :
529
+ """Checks whether a fine-tuned model exists."""
530
+ return self .model_metadata and (
531
+ self .model_metadata .get ("FINE_TUNING_MODEL_PATH" )
532
+ or self .model_metadata .get ("FINE_TUNING_JOB_NAME" )
533
+ )
534
+
535
def _update_model_data_for_fine_tuned_model(self, pysdk_model: Model) -> Model:
    """Point ``pysdk_model`` at fine-tuned artifacts and tag it accordingly.

    Resolves the fine-tuned model location from ``self.model_metadata``:
    either a direct ``FINE_TUNING_MODEL_PATH`` S3/HTTPS URI, or the output
    location of the training job named by ``FINE_TUNING_JOB_NAME``.

    Args:
        pysdk_model (Model): The JumpStart model whose ``model_data`` and
            tags are updated in place.

    Returns:
        Model: The same ``pysdk_model``, updated.

    Raises:
        ValueError: If the model path is malformed, the training job cannot
            be described, or neither metadata key is present.
    """
    # TODO: determine precedence of FINE_TUNING_MODEL_PATH and FINE_TUNING_JOB_NAME
    if fine_tuning_model_path := self.model_metadata.get("FINE_TUNING_MODEL_PATH"):
        # Accept only s3:// or https:// URIs with a non-empty bucket/host part.
        if not re.match(r"^(https|s3)://([^/]+)/?(.*)$", fine_tuning_model_path):
            raise ValueError(
                f"Invalid path for FINE_TUNING_MODEL_PATH: {fine_tuning_model_path}."
            )
        pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
        pysdk_model.add_tags(
            {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path}
        )
        return pysdk_model

    if fine_tuning_job_name := self.model_metadata.get("FINE_TUNING_JOB_NAME"):
        # Keep the try body to the single call that can raise ClientError;
        # chain the AWS error so callers can still see the root cause.
        try:
            response = self.sagemaker_session.sagemaker_client.describe_training_job(
                TrainingJobName=fine_tuning_job_name
            )
        except ClientError as e:
            raise ValueError(
                f"Invalid job name for FINE_TUNING_JOB_NAME: {fine_tuning_job_name}."
            ) from e
        fine_tuning_model_path = response["OutputDataConfig"]["S3OutputPath"]
        pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
        # NOTE(review): assumes DescribeTrainingJob's OutputDataConfig carries a
        # CompressionType entry — confirm against the API response schema.
        pysdk_model.model_data["S3DataSource"]["CompressionType"] = response[
            "OutputDataConfig"
        ]["CompressionType"]
        pysdk_model.add_tags(
            [
                {"key": Tag.FINE_TUNING_JOB_NAME, "value": fine_tuning_job_name},
                {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path},
            ]
        )
        return pysdk_model

    raise ValueError(
        "Input model not found. Please provide either `model_path`, or "
        "`FINE_TUNING_MODEL_PATH` or `FINE_TUNING_JOB_NAME` under `model_metadata`."
    )
575
+
523
576
def _build_for_jumpstart (self ):
524
577
"""Placeholder docstring"""
525
578
if hasattr (self , "pysdk_model" ) and self .pysdk_model is not None :
@@ -534,6 +587,9 @@ def _build_for_jumpstart(self):
534
587
535
588
logger .info ("JumpStart ID %s is packaged with Image URI: %s" , self .model , image_uri )
536
589
590
+ if self ._is_fine_tuned_model ():
591
+ pysdk_model = self ._update_model_data_for_fine_tuned_model (pysdk_model )
592
+
537
593
if self ._is_gated_model (pysdk_model ) and self .mode != Mode .SAGEMAKER_ENDPOINT :
538
594
raise ValueError (
539
595
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
@@ -714,7 +770,7 @@ def _optimize_for_jumpstart(
714
770
** create_optimization_job_args
715
771
)
716
772
717
- def _is_gated_model (self , model ) -> bool :
773
+ def _is_gated_model (self , model : Model ) -> bool :
718
774
"""Determine if ``this`` Model is Gated
719
775
720
776
Args:
0 commit comments