|
90 | 90 | get_add_model_package_inference_args, |
91 | 91 | get_update_model_package_inference_args, |
92 | 92 | ) |
| 93 | +from sagemaker.model_life_cycle import ModelLifeCycle |
93 | 94 |
|
94 | 95 | # Setting LOGGER for backward compatibility, in case users import it... |
95 | 96 | logger = LOGGER = logging.getLogger("sagemaker") |
@@ -473,6 +474,7 @@ def register( |
473 | 474 | skip_model_validation: Optional[Union[str, PipelineVariable]] = None, |
474 | 475 | source_uri: Optional[Union[str, PipelineVariable]] = None, |
475 | 476 | model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, |
| 477 | + model_life_cycle: Optional[ModelLifeCycle] = None, |
476 | 478 | accept_eula: Optional[bool] = None, |
477 | 479 | model_type: Optional[JumpStartModelType] = None, |
478 | 480 | ): |
@@ -528,6 +530,7 @@ def register( |
528 | 530 | (default: None). |
529 | 531 | model_card (ModeCard or ModelPackageModelCard): document contains qualitative and |
530 | 532 | quantitative information about a model (default: None). |
| 533 | + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). |
531 | 534 |
|
532 | 535 | Returns: |
533 | 536 | A `sagemaker.model.ModelPackage` instance or pipeline step arguments |
@@ -597,6 +600,7 @@ def register( |
597 | 600 | skip_model_validation=skip_model_validation, |
598 | 601 | source_uri=source_uri, |
599 | 602 | model_card=model_card, |
| 603 | + model_life_cycle=model_life_cycle, |
600 | 604 | ) |
601 | 605 | model_package = self.sagemaker_session.create_model_package_from_containers( |
602 | 606 | **model_pkg_args |
@@ -2385,6 +2389,23 @@ def update_source_uri( |
2385 | 2389 | sagemaker_session = self.sagemaker_session or sagemaker.Session() |
2386 | 2390 | sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args) |
2387 | 2391 |
|
| 2392 | + def update_model_life_cycle( |
| 2393 | + self, |
| 2394 | + model_life_cycle: ModelLifeCycle, |
| 2395 | + ): |
| 2396 | + """Modellifecycle to be set for the model package |
| 2397 | +
|
| 2398 | + Args: |
| 2399 | + model_life_cycle (ModelLifeCycle): The current state of model package in its life cycle |
| 2400 | +
|
| 2401 | + """ |
| 2402 | + update_model_life_cycle_args = { |
| 2403 | + "ModelPackageArn": self.model_package_arn, |
| 2404 | + "ModelLifeCycle": model_life_cycle, |
| 2405 | + } |
| 2406 | + sagemaker_session = self.sagemaker_session or sagemaker.Session() |
| 2407 | + sagemaker_session.sagemaker_client.update_model_package(**update_model_life_cycle_args) |
| 2408 | + |
2388 | 2409 | def remove_customer_metadata_properties( |
2389 | 2410 | self, customer_metadata_properties_to_remove: List[str] |
2390 | 2411 | ): |
|
0 commit comments