1+ from typing import Any
2+
3+ from llama_index .core .base .embeddings .base import Embedding
14from llama_index .core .embeddings import BaseEmbedding
5+ from typing_extensions import deprecated
26from unstract .adapters .constants import Common
37from unstract .adapters .embedding import adapters
48
59from unstract .sdk .adapters import ToolAdapter
6- from unstract .sdk .constants import LogLevel
7- from unstract .sdk .exceptions import SdkError , ToolEmbeddingError
10+ from unstract .sdk .constants import LogLevel , ToolEnv
11+ from unstract .sdk .exceptions import EmbeddingError , SdkError
812from unstract .sdk .tool .base import BaseTool
13+ from unstract .sdk .utils .callback_manager import CallbackManager
14+
915
16+ class Embedding :
17+ _TEST_SNIPPET = "Hello, I am Unstract"
18+ MAX_TOKENS = 1024 * 16
19+ embedding_adapters = adapters
1020
11- class ToolEmbedding :
12- __TEST_SNIPPET = "Hello, I am Unstract"
def __init__(
    self,
    tool: BaseTool,
    adapter_instance_id: str,
    usage_kwargs: dict[Any, Any] = None,
):
    """Build the embedding wrapper for one adapter instance.

    Resolves the embedding instance for ``adapter_instance_id``, records
    the length of the vector it produces for the class test snippet, and
    registers a callback manager for the resolved embedding instance.

    Args:
        tool: Tool used to fetch the adapter configuration, read
            environment values and stream logs.
        adapter_instance_id: Identifier of the embedding adapter instance.
        usage_kwargs: Optional extra key/values forwarded to the callback
            manager. The mapping is copied, so the caller's dict is never
            mutated. ``None`` (the default) is treated as an empty dict.

    Raises:
        EmbeddingError: Propagated from ``_get_embedding`` when the
            embedding instance cannot be obtained.
    """
    self._tool = tool
    self._adapter_instance_id = adapter_instance_id
    self._embedding_instance: BaseEmbedding = self._get_embedding()
    self._length: int = self._get_embedding_length()

    # The default is None; calling .copy() on it unconditionally would
    # raise AttributeError, so fall back to a fresh dict in that case.
    self._usage_kwargs = usage_kwargs.copy() if usage_kwargs is not None else {}
    self._usage_kwargs["adapter_instance_id"] = adapter_instance_id
    platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
    CallbackManager.set_callback_manager(
        platform_api_key=platform_api_key,
        model=self._embedding_instance,
        kwargs=self._usage_kwargs,
    )
1840
19- def get_embedding (self , adapter_instance_id : str ) -> BaseEmbedding :
41+ def _get_embedding (self ) -> BaseEmbedding :
2042 """Gets an instance of LlamaIndex's embedding object.
2143
2244 Args:
@@ -27,7 +49,7 @@ def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
2749 """
2850 try :
2951 embedding_config_data = ToolAdapter .get_adapter_config (
30- self .tool , adapter_instance_id
52+ self ._tool , self . _adapter_instance_id
3153 )
3254 embedding_adapter_id = embedding_config_data .get (Common .ADAPTER_ID )
3355 if embedding_adapter_id not in self .embedding_adapters :
@@ -42,12 +64,25 @@ def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
4264 embedding_adapter_class = embedding_adapter (embedding_metadata )
4365 return embedding_adapter_class .get_embedding_instance ()
4466 except Exception as e :
45- self .tool .stream_log (
67+ self ._tool .stream_log (
4668 log = f"Error getting embedding: { e } " , level = LogLevel .ERROR
4769 )
48- raise ToolEmbeddingError (f"Error getting embedding instance: { e } " ) from e
70+ raise EmbeddingError (f"Error getting embedding instance: { e } " ) from e
4971
50- def get_embedding_length (self , embedding : BaseEmbedding ) -> int :
51- embedding_list = embedding ._get_text_embedding (self .__TEST_SNIPPET )
def get_query_embedding(self, query: str) -> Embedding:
    """Return the embedding of ``query`` from the wrapped embedding instance.

    Args:
        query: Text to embed.

    Returns:
        Whatever ``get_query_embedding`` of the underlying embedding
        instance returns for ``query``.
    """
    # NOTE(review): the annotation name ``Embedding`` is imported from
    # llama_index at the top of the file and is also the name of this
    # class; confirm it resolves to the llama_index type as intended.
    embedding_model = self._embedding_instance
    return embedding_model.get_query_embedding(query)
74+
def _get_embedding_length(self) -> int:
    """Return the length of the vector produced for the test snippet.

    Embeds the class-level ``_TEST_SNIPPET`` with the wrapped embedding
    instance and returns ``len()`` of the result.
    """
    return len(self._embedding_instance._get_text_embedding(self._TEST_SNIPPET))
81+
@deprecated("Use the new class Embedding")
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
    """Return the vector length produced by ``embedding`` (deprecated).

    Kept for callers of the legacy API, which passed the embedding object
    explicitly. The private ``_get_embedding_length`` accepts no argument,
    so forwarding ``embedding`` to it would raise ``TypeError``; the
    length is therefore computed here from the supplied object.

    Args:
        embedding: Embedding object whose output length is measured.

    Returns:
        ``len()`` of the vector ``embedding`` produces for the class
        test snippet.
    """
    return len(embedding._get_text_embedding(self._TEST_SNIPPET))
85+
86+
# Legacy alias: binds the former class name ``ToolEmbedding`` to ``Embedding``.
ToolEmbedding = Embedding
0 commit comments