diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py index de1a81ed..db75afbc 100644 --- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -133,9 +133,15 @@ class LLMContentHandler(ContentHandlerBase[str, str]): class SagemakerEndpoint(LLM): """Sagemaker Inference Endpoint models. - To use, you must supply the endpoint name from your deployed + To use with a pre-deployed SageMaker endpoint or inference component, you must + supply the endpoint name and optional inference component name from your deployed Sagemaker model & the region where it is deployed. + To use with undeployed SageMaker resources, you can supply an endpoint name, + optional inference component name, and deployment configuration which defines + the endpoint and model configs. This construct can then be used by the SageMaker + PythonSDK ModelBuilder class to deploy a Sagemaker model on the desired compute. + To authenticate, the AWS client uses the following methods to automatically load credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html @@ -191,8 +197,7 @@ class SagemakerEndpoint(LLM): region_name=region_name, credentials_profile_name=credentials_profile_name ) - - #Use with boto3 client + # Usage with boto3 client client = boto3.client( "sagemaker-runtime", region_name=region_name @@ -208,7 +213,7 @@ class SagemakerEndpoint(LLM): """Boto3 client for sagemaker runtime""" endpoint_name: str = "" - """The name of the endpoint from the deployed Sagemaker model. + """The name of the endpoint created from a Sagemaker model. Must be unique within an AWS Region.""" inference_component_name: Optional[str] = None @@ -263,6 +268,33 @@ def transform_output(self, output: bytes) -> str: .. _boto3: """ + deployment_config: Optional[Dict] = None + """The deployment configuration for an undeployed endpoint or inference component + which can be deployed through the Sagemaker Python SDK ModelBuilder class. + Comprises two sub-dictionaries model_config and endpoint_config. + """ + + """ + Schema: + .. code-block:: python + deployment_config = { + "model_config": { + "model": Optional[str], + "model_path": Optional[str], + "image_uri": Optional[str], + "model_server": Optional[str], + "content_type": Optional[str], + "accept_type": Optional[str] + }, + "endpoint_config": { + "resources": Optional[Dict[str, int]], + "instance_type": Optional[str], + "initial_instance_count": Optional[int] + }, + "tags": Optional[List[Dict]] + } + """ + model_config = ConfigDict( extra="forbid", )