29
29
_retrieve_model_package_model_artifact_s3_uri ,
30
30
)
31
31
from sagemaker .jumpstart .artifacts .resource_names import _retrieve_resource_name_base
32
+ from sagemaker .jumpstart .hub .utils import construct_hub_model_arn_from_inputs , construct_hub_model_reference_arn_from_inputs
32
33
from sagemaker .session import Session
33
34
from sagemaker .async_inference .async_inference_config import AsyncInferenceConfig
34
35
from sagemaker .base_deserializers import BaseDeserializer
52
53
from sagemaker .jumpstart .enums import JumpStartScriptScope , JumpStartModelType
53
54
from sagemaker .jumpstart .factory import model
54
55
from sagemaker .jumpstart .types import (
56
+ HubContentType ,
55
57
JumpStartEstimatorDeployKwargs ,
56
58
JumpStartEstimatorFitKwargs ,
57
59
JumpStartEstimatorInitKwargs ,
@@ -201,6 +203,11 @@ def get_init_kwargs(
201
203
estimator_init_kwargs = _add_region_to_kwargs (estimator_init_kwargs )
202
204
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs (estimator_init_kwargs )
203
205
estimator_init_kwargs = _add_image_uri_to_kwargs (estimator_init_kwargs )
206
+ if hub_arn :
207
+ estimator_init_kwargs = _add_model_reference_arn_to_kwargs (kwargs = estimator_init_kwargs )
208
+ else :
209
+ estimator_init_kwargs .model_reference_arn = None
210
+ estimator_init_kwargs .hub_content_type = None
204
211
estimator_init_kwargs = _add_model_uri_to_kwargs (estimator_init_kwargs )
205
212
estimator_init_kwargs = _add_source_dir_to_kwargs (estimator_init_kwargs )
206
213
estimator_init_kwargs = _add_entry_point_to_kwargs (estimator_init_kwargs )
@@ -511,7 +518,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
511
518
)
512
519
513
520
if kwargs .hub_arn :
514
- kwargs .tags = add_hub_content_arn_tags (kwargs .tags , kwargs .hub_arn )
521
+ if kwargs .model_reference_arn :
522
+ hub_content_arn = construct_hub_model_reference_arn_from_inputs (
523
+ kwargs .hub_arn , kwargs .model_id , kwargs .model_version
524
+ )
525
+ else :
526
+ hub_content_arn = construct_hub_model_arn_from_inputs (
527
+ kwargs .hub_arn , kwargs .model_id , kwargs .model_version
528
+ )
529
+ kwargs .tags = add_hub_content_arn_tags (kwargs .tags , hub_content_arn = hub_content_arn )
515
530
516
531
return kwargs
517
532
@@ -534,6 +549,32 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
534
549
535
550
return kwargs
536
551
552
+ def _add_model_reference_arn_to_kwargs (
553
+ kwargs : JumpStartEstimatorInitKwargs ,
554
+ ) -> JumpStartEstimatorInitKwargs :
555
+ """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
556
+
557
+ hub_content_type = verify_model_region_and_return_specs (
558
+ model_id = kwargs .model_id ,
559
+ version = kwargs .model_version ,
560
+ hub_arn = kwargs .hub_arn ,
561
+ scope = JumpStartScriptScope .TRAINING ,
562
+ region = kwargs .region ,
563
+ tolerate_vulnerable_model = kwargs .tolerate_vulnerable_model ,
564
+ tolerate_deprecated_model = kwargs .tolerate_deprecated_model ,
565
+ sagemaker_session = kwargs .sagemaker_session ,
566
+ model_type = kwargs .model_type ,
567
+ ).hub_content_type
568
+ kwargs .hub_content_type = hub_content_type if kwargs .hub_arn else None
569
+
570
+ if hub_content_type == HubContentType .MODEL_REFERENCE :
571
+ kwargs .model_reference_arn = construct_hub_model_reference_arn_from_inputs (
572
+ hub_arn = kwargs .hub_arn , model_name = kwargs .model_id , version = kwargs .model_version
573
+ )
574
+ else :
575
+ kwargs .model_reference_arn = None
576
+ return kwargs
577
+
537
578
538
579
def _add_model_uri_to_kwargs (kwargs : JumpStartEstimatorInitKwargs ) -> JumpStartEstimatorInitKwargs :
539
580
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
0 commit comments