@@ -10,7 +10,7 @@
 """
 
 import types
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from litellm.types.llms.bedrock import (
     AmazonTitanV2EmbeddingRequest,
@@ -30,9 +30,7 @@ class AmazonTitanV2Config:
     normalize: Optional[bool] = None
     dimensions: Optional[int] = None
 
-    def __init__(
-        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
-    ) -> None:
+    def __init__(self, normalize: Optional[bool] = None, dimensions: Optional[int] = None) -> None:
         locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
@@ -57,32 +55,56 @@ def get_config(cls):
         }
 
     def get_supported_openai_params(self) -> List[str]:
-        return ["dimensions"]
+        return ["dimensions", "encoding_format"]
 
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict
-    ) -> dict:
+    def map_openai_params(self, non_default_params: dict, optional_params: dict) -> dict:
         for k, v in non_default_params.items():
             if k == "dimensions":
                 optional_params["dimensions"] = v
+            elif k == "encoding_format":
+                # Map OpenAI encoding_format to AWS embeddingTypes
+                if v == "float":
+                    optional_params["embeddingTypes"] = ["float"]
+                elif v == "base64":
+                    # base64 maps to binary format in AWS
+                    optional_params["embeddingTypes"] = ["binary"]
+                else:
+                    # For any other encoding format, default to float
+                    optional_params["embeddingTypes"] = ["float"]
         return optional_params
 
-    def _transform_request(
-        self, input: str, inference_params: dict
-    ) -> AmazonTitanV2EmbeddingRequest:
+    def _transform_request(self, input: str, inference_params: dict) -> AmazonTitanV2EmbeddingRequest:
         return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore
 
-    def _transform_response(
-        self, response_list: List[dict], model: str
-    ) -> EmbeddingResponse:
+    def _transform_response(self, response_list: List[dict], model: str) -> EmbeddingResponse:
         total_prompt_tokens = 0
 
         transformed_responses: List[Embedding] = []
         for index, response in enumerate(response_list):
             _parsed_response = AmazonTitanV2EmbeddingResponse(**response)  # type: ignore
+
+            # According to AWS docs, embeddingsByType is always present
+            # If binary was requested (encoding_format="base64"), use binary data
+            # Otherwise, use float data from embeddingsByType or fallback to embedding field
+            embedding_data: Union[List[float], List[int]]
+
+            if ("embeddingsByType" in _parsed_response and
+                    "binary" in _parsed_response["embeddingsByType"]):
+                # Use binary data if available (for encoding_format="base64")
+                embedding_data = _parsed_response["embeddingsByType"]["binary"]
+            elif ("embeddingsByType" in _parsed_response and
+                    "float" in _parsed_response["embeddingsByType"]):
+                # Use float data from embeddingsByType
+                embedding_data = _parsed_response["embeddingsByType"]["float"]
+            elif "embedding" in _parsed_response:
+                # Fallback to legacy embedding field
+                embedding_data = _parsed_response["embedding"]
+            else:
+                raise ValueError(f"No embedding data found in response: {response}")
+
             transformed_responses.append(
                 Embedding(
                     embedding=embedding_data,
                     index=index,
                     object="embedding",
                 )
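A standalone sketch of the two behaviours this diff introduces follows, using only the Python standard library. The helper names map_encoding_format and pick_embedding are illustrative only and do not exist in litellm; they mirror the elif branch added to map_openai_params and the selection order added to _transform_response.

# Standalone sketch (not the litellm API): OpenAI's `encoding_format` parameter is
# translated to the Bedrock Titan V2 request field `embeddingTypes`, and the response
# vector is read back from `embeddingsByType`, falling back to the legacy `embedding` field.
from typing import List, Union


def map_encoding_format(encoding_format: str) -> dict:
    """Mirror of the elif branch added to map_openai_params."""
    if encoding_format == "base64":
        # base64 output is served from Titan's binary embeddings
        return {"embeddingTypes": ["binary"]}
    # "float" and any unrecognized value default to float embeddings
    return {"embeddingTypes": ["float"]}


def pick_embedding(response: dict) -> Union[List[float], List[int]]:
    """Mirror of the selection order in _transform_response."""
    by_type = response.get("embeddingsByType", {})
    if "binary" in by_type:
        return by_type["binary"]
    if "float" in by_type:
        return by_type["float"]
    if "embedding" in response:
        return response["embedding"]
    raise ValueError(f"No embedding data found in response: {response}")


# Example: a base64 request prefers the binary vector in the response
print(map_encoding_format("base64"))                                  # {'embeddingTypes': ['binary']}
print(pick_embedding({"embeddingsByType": {"binary": [1, 0, 1]}}))    # [1, 0, 1]
print(pick_embedding({"embedding": [0.12, -0.34]}))                   # legacy fallback: [0.12, -0.34]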