|
36 | 36 | get_init_kwargs, |
37 | 37 | get_register_kwargs, |
38 | 38 | ) |
| 39 | +from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint |
39 | 40 | from sagemaker.jumpstart.types import JumpStartSerializablePayload |
40 | 41 | from sagemaker.jumpstart.utils import ( |
41 | 42 | validate_model_id_and_get_type, |
42 | 43 | verify_model_region_and_return_specs, |
43 | 44 | ) |
44 | | -from sagemaker.jumpstart.constants import JUMPSTART_LOGGER |
| 45 | +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER |
45 | 46 | from sagemaker.jumpstart.enums import JumpStartModelType |
46 | 47 | from sagemaker.model_card import ( |
47 | 48 | ModelCard, |
@@ -406,6 +407,45 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: |
406 | 407 | sagemaker_session=self.sagemaker_session, |
407 | 408 | ) |
408 | 409 |
|
@classmethod
def attach(
    cls,
    endpoint_name: str,
    inference_component_name: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
    sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> "JumpStartModel":
    """Attaches a JumpStartModel object to an existing SageMaker Endpoint.

    The model id, version (and inference component name) can be inferred from the tags.
    Explicitly supplied values always take precedence over inferred ones.

    Args:
        endpoint_name (str): Name of the existing endpoint to attach to.
        inference_component_name (Optional[str]): Name of the inference component.
            If not provided, the value returned by the endpoint lookup is used.
            (Default: None).
        model_id (Optional[str]): JumpStart model id. If not provided, the value
            returned by the endpoint lookup is used. (Default: None).
        model_version (Optional[str]): JumpStart model version. If not provided,
            the value returned by the endpoint lookup is used, falling back to
            ``"*"`` when the lookup yields none. (Default: None).
        sagemaker_session: Session forwarded to the endpoint lookup and to the
            model constructor. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        JumpStartModel: A model object whose ``endpoint_name`` and
        ``inference_component_name`` attributes are set to the resolved values.
    """

    inferred_model_id = inferred_model_version = inferred_inference_component_name = None

    # A lookup is only needed when at least one of the three values is missing.
    if inference_component_name is None or model_id is None or model_version is None:
        (
            inferred_model_id,
            inferred_model_version,
            inferred_inference_component_name,
        ) = get_model_id_version_from_endpoint(
            endpoint_name=endpoint_name,
            inference_component_name=inference_component_name,
            sagemaker_session=sagemaker_session,
        )

    # `or` is used deliberately: any falsy caller value (None or "") falls back
    # to the inferred one.
    model_id = model_id or inferred_model_id
    # NOTE(review): "*" is presumably the "any version" wildcard -- confirm.
    model_version = model_version or inferred_model_version or "*"
    inference_component_name = inference_component_name or inferred_inference_component_name

    # Construct through `cls` rather than a hard-coded class name so that calling
    # `attach` on a subclass yields an instance of that subclass.
    model = cls(
        model_id=model_id,
        model_version=model_version,
        sagemaker_session=sagemaker_session,
    )
    model.endpoint_name = endpoint_name
    model.inference_component_name = inference_component_name

    return model
| 448 | + |
409 | 449 | def _create_sagemaker_model( |
410 | 450 | self, |
411 | 451 | instance_type=None, |
@@ -484,6 +524,7 @@ def deploy( |
484 | 524 | deserializer: Optional[BaseDeserializer] = None, |
485 | 525 | accelerator_type: Optional[str] = None, |
486 | 526 | endpoint_name: Optional[str] = None, |
| 527 | + inference_component_name: Optional[str] = None, |
487 | 528 | tags: Optional[Tags] = None, |
488 | 529 | kms_key: Optional[str] = None, |
489 | 530 | wait: Optional[bool] = True, |
@@ -614,6 +655,7 @@ def deploy( |
614 | 655 | deserializer=deserializer, |
615 | 656 | accelerator_type=accelerator_type, |
616 | 657 | endpoint_name=endpoint_name, |
| 658 | + inference_component_name=inference_component_name, |
617 | 659 | tags=format_tags(tags), |
618 | 660 | kms_key=kms_key, |
619 | 661 | wait=wait, |
|
0 commit comments