1- from typing import Any
1+ from typing import Any , Optional
22
33from llama_index .core .base .embeddings .base import Embedding
44from llama_index .core .embeddings import BaseEmbedding
@@ -21,22 +21,27 @@ class Embedding:
2121 def __init__ (
2222 self ,
2323 tool : BaseTool ,
24- adapter_instance_id : str ,
24+ adapter_instance_id : Optional [ str ] = None ,
2525 usage_kwargs : dict [Any , Any ] = {},
2626 ):
2727 self ._tool = tool
2828 self ._adapter_instance_id = adapter_instance_id
29- self ._embedding_instance : BaseEmbedding = self ._get_embedding ()
30- self ._length : int = self ._get_embedding_length ()
31-
32- self ._usage_kwargs = usage_kwargs .copy ()
33- self ._usage_kwargs ["adapter_instance_id" ] = adapter_instance_id
34- platform_api_key = self ._tool .get_env_or_die (ToolEnv .PLATFORM_API_KEY )
35- CallbackManager .set_callback_manager (
36- platform_api_key = platform_api_key ,
37- model = self ._embedding_instance ,
38- kwargs = self ._usage_kwargs ,
39- )
29+ self ._embedding_instance : BaseEmbedding = None
30+ self ._length : int = None
31+ self ._usage_kwargs = usage_kwargs
32+ self ._initialise ()
33+
34+ def _initialise (self ):
35+ if self ._adapter_instance_id :
36+ self ._embedding_instance = self ._get_embedding ()
37+ self ._length : int = self ._get_embedding_length ()
38+ self ._usage_kwargs ["adapter_instance_id" ] = self ._adapter_instance_id
39+ platform_api_key = self ._tool .get_env_or_die (ToolEnv .PLATFORM_API_KEY )
40+ CallbackManager .set_callback (
41+ platform_api_key = platform_api_key ,
42+ model = self ._embedding_instance ,
43+ kwargs = self ._usage_kwargs ,
44+ )
4045
4146 def _get_embedding (self ) -> BaseEmbedding :
4247 """Gets an instance of LlamaIndex's embedding object.
@@ -48,6 +53,10 @@ def _get_embedding(self) -> BaseEmbedding:
4853 BaseEmbedding: Embedding instance
4954 """
5055 try :
56+ if not self ._adapter_instance_id :
57+ raise EmbeddingError (
58+ "Adapter instance ID not set. " "Initialisation failed"
59+ )
5160 embedding_config_data = ToolAdapter .get_adapter_config (
5261 self ._tool , self ._adapter_instance_id
5362 )
@@ -79,9 +88,27 @@ def _get_embedding_length(self) -> int:
7988 embedding_dimension = len (embedding_list )
8089 return embedding_dimension
8190
82- @deprecated ("Use the new class Embedding" )
91+ def get_class_name (self ) -> str :
92+ """Gets the class name of the Llama Index Embedding.
93+
94+ Args:
95+ NA
96+
97+ Returns:
98+ Class name
99+ """
100+ return self ._embedding_instance .class_name ()
101+
102+ @deprecated ("Use Embedding instead of ToolEmbedding" )
83103 def get_embedding_length (self , embedding : BaseEmbedding ) -> int :
84- return self ._get_embedding_length (embedding )
104+ return self ._get_embedding_length ()
105+
106+ @deprecated ("Use Embedding instead of ToolEmbedding" )
107+ def get_embedding (self , adapter_instance_id : str ) -> BaseEmbedding :
108+ if not self ._embedding_instance :
109+ self ._adapter_instance_id = adapter_instance_id
110+ self ._initialise ()
111+ return self ._embedding_instance
85112
86113
87114# Legacy
0 commit comments