|
62 | 62 | from sagemaker.serve.utils import task
|
63 | 63 | from sagemaker.serve.utils.exceptions import TaskNotFoundException
|
64 | 64 | from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
|
| 65 | +from sagemaker.serve.utils.optimize_utils import ( |
| 66 | + _is_compatible_with_compilation, |
| 67 | + _poll_optimization_job, |
| 68 | +) |
65 | 69 | from sagemaker.serve.utils.predictors import _get_local_mode_predictor
|
66 | 70 | from sagemaker.serve.utils.hardware_detector import (
|
67 | 71 | _get_gpu_info,
|
|
83 | 87 | from sagemaker.serve.validations.check_image_and_hardware_type import (
|
84 | 88 | validate_image_uri_and_hardware,
|
85 | 89 | )
|
| 90 | +from sagemaker.utils import Tags |
86 | 91 | from sagemaker.workflow.entities import PipelineVariable
|
87 | 92 | from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
|
88 | 93 |
|
@@ -804,8 +809,15 @@ def save(
|
804 | 809 | This function is available for models served by DJL serving.
|
805 | 810 |
|
806 | 811 | Args:
|
807 | | - save_path (Optional[str]): The path where you want to save resources. |
808 | | - s3_path (Optional[str]): The path where you want to upload resources. |
| 812 | + save_path (Optional[str]): The path where you want to save resources. Defaults to |
| 813 | + ``None``. |
| 814 | + s3_path (Optional[str]): The path where you want to upload resources. Defaults to |
| 815 | + ``None``. |
| 816 | + sagemaker_session (Optional[Session]): Session object which manages interactions |
| 817 | + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the |
| 818 | + function creates one using the default AWS configuration chain. Defaults to |
| 819 | + ``None``. |
| 820 | + role_arn (Optional[str]): The IAM role arn. Defaults to ``None``. |
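| | + |
| | + Example (illustrative sketch only; the local path and S3 URI are assumed values, not part of this change): |
| | + |
| | + >>> model_builder.save( |
| | + ... save_path="./my_model_resources", |
| | + ... s3_path="s3://my-bucket/my_model_resources", |
| | + ... ) |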
809 | 821 | """
|
810 | 822 | self.sagemaker_session = sagemaker_session or Session()
|
811 | 823 |
|
@@ -915,3 +927,129 @@ def _try_fetch_gpu_info(self):
|
915 | 927 | raise ValueError(
|
916 | 928 | f"Unable to determine single GPU size for instance: [{self.instance_type}]"
|
917 | 929 | )
|
| 930 | + |
| 931 | + def optimize(self, *args, **kwargs) -> Type[Model]: |
| 932 | + """Runs a model optimization job. |
| 933 | + |
| 934 | + Args: |
| 935 | + instance_type (str): Target deployment instance type that the model is optimized for. |
| 936 | + output_path (str): Specifies where to store the compiled/quantized model. |
| 937 | + role (Optional[str]): Execution role. Defaults to ``None``. |
| 938 | + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. |
| 939 | + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. |
| 940 | + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. |
| 941 | + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. |
| 942 | + env_vars (Optional[Dict]): Additional environment variables to run the optimization |
| 943 | + container. Defaults to ``None``. |
| 944 | + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. |
| 945 | + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading |
| 946 | + to S3. Defaults to ``None``. |
| 947 | + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to |
| 948 | + ``None``. |
| 949 | + sagemaker_session (Optional[Session]): Session object which manages interactions |
| 950 | + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the |
| 951 | + function creates one using the default AWS configuration chain. |
| 952 | + |
| 953 | + Returns: |
| 954 | + Type[Model]: A deployable ``Model`` object. |
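| | + |
| | + Example (illustrative sketch only; the instance type, S3 output path, and config values are assumed, not part of this change): |
| | + |
| | + >>> optimized_model = model_builder.optimize( |
| | + ... instance_type="ml.inf2.xlarge", |
| | + ... output_path="s3://my-bucket/optimized-model/", |
| | + ... compilation_config={"OverrideEnvironment": {"OPTION_DTYPE": "fp16"}}, |
| | + ... ) |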
| 955 | + """ |
| 956 | + # need to get telemetry_opt_out info before telemetry decorator is called |
| 957 | + self.serve_settings = self._get_serve_setting() |
| 958 | + |
| 959 | + return self._model_builder_optimize_wrapper(*args, **kwargs) |
| 960 | + |
| 961 | + @_capture_telemetry("optimize") |
| 962 | + def _model_builder_optimize_wrapper( |
| 963 | + self, |
| 964 | + instance_type: str, |
| 965 | + output_path: str, |
| 966 | + role: Optional[str] = None, |
| 967 | + tags: Optional[Tags] = None, |
| 968 | + job_name: Optional[str] = None, |
| 969 | + quantization_config: Optional[Dict] = None, |
| 970 | + compilation_config: Optional[Dict] = None, |
| 971 | + env_vars: Optional[Dict] = None, |
| 972 | + vpc_config: Optional[Dict] = None, |
| 973 | + kms_key: Optional[str] = None, |
| 974 | + max_runtime_in_sec: Optional[int] = None, |
| 975 | + sagemaker_session: Optional[Session] = None, |
| 976 | + ) -> Type[Model]: |
| 977 | + """Runs a model optimization job. |
| 978 | + |
| 979 | + Args: |
| 980 | + instance_type (str): Target deployment instance type that the model is optimized for. |
| 981 | + output_path (str): Specifies where to store the compiled/quantized model. |
| 982 | + role (Optional[str]): Execution role. Defaults to ``None``. |
| 983 | + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. |
| 984 | + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. |
| 985 | + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. |
| 986 | + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. |
| 987 | + env_vars (Optional[Dict]): Additional environment variables to run the optimization |
| 988 | + container. Defaults to ``None``. |
| 989 | + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. |
| 990 | + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading |
| 991 | + to S3. Defaults to ``None``. |
| 992 | + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to |
| 993 | + ``None``. |
| 994 | + sagemaker_session (Optional[Session]): Session object which manages interactions |
| 995 | + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the |
| 996 | + function creates one using the default AWS configuration chain. |
| 997 | +
|
| 998 | + Returns: |
| 999 | + Type[Model]: A deployable ``Model`` object. |
| 1000 | + """ |
| 1001 | + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() |
| 1002 | + |
| 1003 | + # TODO: inject actual model source location based on different scenarios |
| 1004 | + model_source = {"S3": {"S3Uri": self.model_path, "ModelAccessConfig": {"AcceptEula": True}}} |
| 1005 | + |
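| | + # Gather the requested optimization steps; quantization and compilation each become an entry in OptimizationConfigs. |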
| 1006 | + optimization_configs = [] |
| 1007 | + if quantization_config: |
| 1008 | + optimization_configs.append({"ModelQuantizationConfig": quantization_config}) |
| 1009 | + if compilation_config: |
| 1010 | + if _is_compatible_with_compilation(instance_type): |
| 1011 | + optimization_configs.append({"ModelCompilationConfig": compilation_config}) |
| 1012 | + else: |
| 1013 | + logger.warning( |
| 1014 | + "Model compilation is currently only supported for Inferentia and Trainium" |
| 1015 | + "instances, ignoring `compilation_config'." |
| 1016 | + ) |
| 1017 | + |
| 1018 | + output_config = {"S3OutputLocation": output_path} |
| 1019 | + if kms_key: |
| 1020 | + output_config["KmsKeyId"] = kms_key |
| 1021 | + |
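| | + # Assemble the CreateOptimizationJob request; optional fields are added below only when provided. |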
| 1022 | + job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" |
| 1023 | + create_optimization_job_args = { |
| 1024 | + "OptimizationJobName": job_name, |
| 1025 | + "ModelSource": model_source, |
| 1026 | + "DeploymentInstanceType": instance_type, |
| 1027 | + "OptimizationConfigs": optimization_configs, |
| 1028 | + "OutputConfig": output_config, |
| 1029 | + "RoleArn": role or self.role_arn, |
| 1030 | + } |
| 1031 | + |
| 1032 | + if env_vars: |
| 1033 | + create_optimization_job_args["OptimizationEnvironment"] = env_vars |
| 1034 | + |
| 1035 | + if max_runtime_in_sec: |
| 1036 | + create_optimization_job_args["StoppingCondition"] = { |
| 1037 | + "MaxRuntimeInSeconds": max_runtime_in_sec |
| 1038 | + } |
| 1039 | + |
| 1040 | + # TODO: tag injection if it is a JumpStart model |
| 1041 | + if tags: |
| 1042 | + create_optimization_job_args["Tags"] = tags |
| 1043 | + |
| 1044 | + if vpc_config: |
| 1045 | + create_optimization_job_args["VpcConfig"] = vpc_config |
| 1046 | + |
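| | + # Submit the optimization job, then poll until it finishes or the wait times out. |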
| 1047 | + response = self.sagemaker_session.sagemaker_client.create_optimization_job( |
| 1048 | + **create_optimization_job_args |
| 1049 | + ) |
| 1050 | + |
| 1051 | + if not _poll_optimization_job(job_name, self.sagemaker_session): |
| 1052 | + raise Exception("Optimization job timed out.") |
| 1053 | + |
| 1054 | + # TODO: return model created by optimization job |
| 1055 | + return response |