Skip to content

Commit e8979f5

Browse files
SDK backward compatibility changes (#55)
* Backward comaptibility * Make SDK backward compatible * Fix for run_id population * Address review comments * roll sdk version * Make param optiona; * Update src/unstract/sdk/vector_db.py Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: Gayathri <[email protected]> --------- Signed-off-by: Gayathri <[email protected]> Co-authored-by: Chandrasekharan M <[email protected]>
1 parent 0a69375 commit e8979f5

File tree

10 files changed

+427
-190
lines changed

10 files changed

+427
-190
lines changed

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.30.0"
1+
__version__ = "0.31.0"
22

33

44
def get_sdk_version():

src/unstract/sdk/embedding.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
from llama_index.core.base.embeddings.base import Embedding
44
from llama_index.core.embeddings import BaseEmbedding
@@ -21,22 +21,27 @@ class Embedding:
2121
def __init__(
2222
self,
2323
tool: BaseTool,
24-
adapter_instance_id: str,
24+
adapter_instance_id: Optional[str] = None,
2525
usage_kwargs: dict[Any, Any] = {},
2626
):
2727
self._tool = tool
2828
self._adapter_instance_id = adapter_instance_id
29-
self._embedding_instance: BaseEmbedding = self._get_embedding()
30-
self._length: int = self._get_embedding_length()
31-
32-
self._usage_kwargs = usage_kwargs.copy()
33-
self._usage_kwargs["adapter_instance_id"] = adapter_instance_id
34-
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
35-
CallbackManager.set_callback_manager(
36-
platform_api_key=platform_api_key,
37-
model=self._embedding_instance,
38-
kwargs=self._usage_kwargs,
39-
)
29+
self._embedding_instance: BaseEmbedding = None
30+
self._length: int = None
31+
self._usage_kwargs = usage_kwargs
32+
self._initialise()
33+
34+
def _initialise(self):
35+
if self._adapter_instance_id:
36+
self._embedding_instance = self._get_embedding()
37+
self._length: int = self._get_embedding_length()
38+
self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
39+
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
40+
CallbackManager.set_callback(
41+
platform_api_key=platform_api_key,
42+
model=self._embedding_instance,
43+
kwargs=self._usage_kwargs,
44+
)
4045

4146
def _get_embedding(self) -> BaseEmbedding:
4247
"""Gets an instance of LlamaIndex's embedding object.
@@ -48,6 +53,10 @@ def _get_embedding(self) -> BaseEmbedding:
4853
BaseEmbedding: Embedding instance
4954
"""
5055
try:
56+
if not self._adapter_instance_id:
57+
raise EmbeddingError(
58+
"Adapter instance ID not set. " "Initialisation failed"
59+
)
5160
embedding_config_data = ToolAdapter.get_adapter_config(
5261
self._tool, self._adapter_instance_id
5362
)
@@ -79,9 +88,27 @@ def _get_embedding_length(self) -> int:
7988
embedding_dimension = len(embedding_list)
8089
return embedding_dimension
8190

82-
@deprecated("Use the new class Embedding")
91+
def get_class_name(self) -> str:
92+
"""Gets the class name of the Llama Index Embedding.
93+
94+
Args:
95+
NA
96+
97+
Returns:
98+
Class name
99+
"""
100+
return self._embedding_instance.class_name()
101+
102+
@deprecated("Use Embedding instead of ToolEmbedding")
83103
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
84-
return self._get_embedding_length(embedding)
104+
return self._get_embedding_length()
105+
106+
@deprecated("Use Embedding instead of ToolEmbedding")
107+
def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
108+
if not self._embedding_instance:
109+
self._adapter_instance_id = adapter_instance_id
110+
self._initialise()
111+
return self._embedding_instance
85112

86113

87114
# Legacy

src/unstract/sdk/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,9 @@ class X2TextError(SdkError):
3333
DEFAULT_MESSAGE = "Error ocurred related to text extractor"
3434

3535

36+
class OCRError(SdkError):
37+
DEFAULT_MESSAGE = "Error ocurred related to OCR"
38+
39+
3640
class RateLimitError(SdkError):
3741
DEFAULT_MESSAGE = "Running into rate limit errors, please try again later"

0 commit comments

Comments
 (0)