Skip to content

Commit 0fabae8

Browse files
Support run_id / token usage for prompt runs (#51)
* SDK refactoring * Add support for sdk x2text and ocr * indentation fi * Fixes for SDK refactoring * Remove unwanted line * Address review comments * Updated adapter version * Update src/unstract/sdk/index.py --------- Signed-off-by: Hari John Kuriakose <[email protected]>
1 parent e70ed7f commit 0fabae8

File tree

14 files changed

+1087
-640
lines changed

14 files changed

+1087
-640
lines changed

pdm.lock

Lines changed: 665 additions & 414 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies = [
1212
"python-magic~=0.4.27",
1313
"python-dotenv==1.0.0",
1414
# LLM Triad
15-
"unstract-adapters~=0.15.1",
15+
"unstract-adapters~=0.16.0",
1616
"llama-index==0.10.28",
1717
"tiktoken~=0.4.0",
1818
"transformers==4.37.0",

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.26.1"
1+
__version__ = "0.27.0"
22

33

44
def get_sdk_version():

src/unstract/sdk/audit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
import requests
24
from llama_index.core.callbacks import CBEventType, TokenCountingHandler
35

@@ -26,7 +28,7 @@ def push_usage_data(
2628
token_counter: TokenCountingHandler = None,
2729
model_name: str = "",
2830
event_type: CBEventType = None,
29-
**kwargs,
31+
kwargs: dict[Any, Any] = None,
3032
) -> None:
3133
"""Pushes the usage data to the platform service.
3234
@@ -84,9 +86,7 @@ def push_usage_data(
8486
headers = {"Authorization": f"Bearer {bearer_token}"}
8587

8688
try:
87-
response = requests.post(
88-
url, headers=headers, json=data, timeout=30
89-
)
89+
response = requests.post(url, headers=headers, json=data, timeout=30)
9090
if response.status_code != 200:
9191
self.stream_log(
9292
log=(

src/unstract/sdk/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,8 @@ class ToolSettingsKey:
146146
EMBEDDING_ADAPTER_ID = "embeddingAdapterId"
147147
VECTOR_DB_ADAPTER_ID = "vectorDbAdapterId"
148148
X2TEXT_ADAPTER_ID = "x2TextAdapterId"
149+
ADAPTER_INSTANCE_ID = "adapter_instance_id"
150+
EMBEDDING_DIMENSION = "embedding_dimension"
151+
RUN_ID = "run_id"
152+
WORKFLOW_ID = "workflow_id"
153+
EXECUTION_ID = "execution_id"

src/unstract/sdk/embedding.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,44 @@
1+
from typing import Any
2+
3+
from llama_index.core.base.embeddings.base import Embedding
14
from llama_index.core.embeddings import BaseEmbedding
5+
from typing_extensions import deprecated
26
from unstract.adapters.constants import Common
37
from unstract.adapters.embedding import adapters
48

59
from unstract.sdk.adapters import ToolAdapter
6-
from unstract.sdk.constants import LogLevel
7-
from unstract.sdk.exceptions import SdkError, ToolEmbeddingError
10+
from unstract.sdk.constants import LogLevel, ToolEnv
11+
from unstract.sdk.exceptions import EmbeddingError, SdkError
812
from unstract.sdk.tool.base import BaseTool
13+
from unstract.sdk.utils.callback_manager import CallbackManager
14+
915

16+
class Embedding:
17+
_TEST_SNIPPET = "Hello, I am Unstract"
18+
MAX_TOKENS = 1024 * 16
19+
embedding_adapters = adapters
1020

11-
class ToolEmbedding:
12-
__TEST_SNIPPET = "Hello, I am Unstract"
21+
def __init__(
22+
self,
23+
tool: BaseTool,
24+
adapter_instance_id: str,
25+
usage_kwargs: dict[Any, Any] = None,
26+
):
27+
self._tool = tool
28+
self._adapter_instance_id = adapter_instance_id
29+
self._embedding_instance: BaseEmbedding = self._get_embedding()
30+
self._length: int = self._get_embedding_length()
1331

14-
def __init__(self, tool: BaseTool):
15-
self.tool = tool
16-
self.max_tokens = 1024 * 16
17-
self.embedding_adapters = adapters
32+
self._usage_kwargs = usage_kwargs.copy()
33+
self._usage_kwargs["adapter_instance_id"] = adapter_instance_id
34+
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
35+
CallbackManager.set_callback_manager(
36+
platform_api_key=platform_api_key,
37+
model=self._embedding_instance,
38+
kwargs=self._usage_kwargs,
39+
)
1840

19-
def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
41+
def _get_embedding(self) -> BaseEmbedding:
2042
"""Gets an instance of LlamaIndex's embedding object.
2143
2244
Args:
@@ -27,7 +49,7 @@ def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
2749
"""
2850
try:
2951
embedding_config_data = ToolAdapter.get_adapter_config(
30-
self.tool, adapter_instance_id
52+
self._tool, self._adapter_instance_id
3153
)
3254
embedding_adapter_id = embedding_config_data.get(Common.ADAPTER_ID)
3355
if embedding_adapter_id not in self.embedding_adapters:
@@ -42,12 +64,25 @@ def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
4264
embedding_adapter_class = embedding_adapter(embedding_metadata)
4365
return embedding_adapter_class.get_embedding_instance()
4466
except Exception as e:
45-
self.tool.stream_log(
67+
self._tool.stream_log(
4668
log=f"Error getting embedding: {e}", level=LogLevel.ERROR
4769
)
48-
raise ToolEmbeddingError(f"Error getting embedding instance: {e}") from e
70+
raise EmbeddingError(f"Error getting embedding instance: {e}") from e
4971

50-
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
51-
embedding_list = embedding._get_text_embedding(self.__TEST_SNIPPET)
72+
def get_query_embedding(self, query: str) -> Embedding:
73+
return self._embedding_instance.get_query_embedding(query)
74+
75+
def _get_embedding_length(self) -> int:
76+
embedding_list = self._embedding_instance._get_text_embedding(
77+
self._TEST_SNIPPET
78+
)
5279
embedding_dimension = len(embedding_list)
5380
return embedding_dimension
81+
82+
@deprecated("Use the new class Embedding")
83+
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
84+
return self._get_embedding_length(embedding)
85+
86+
87+
# Legacy
88+
ToolEmbedding = Embedding

src/unstract/sdk/exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ def __init__(self, message: str = ""):
1717
super().__init__(message)
1818

1919

20-
class ToolLLMError(SdkError):
20+
class LLMError(SdkError):
2121
DEFAULT_MESSAGE = "Error ocurred related to LLM"
2222

2323

24-
class ToolEmbeddingError(SdkError):
24+
class EmbeddingError(SdkError):
2525
DEFAULT_MESSAGE = "Error ocurred related to embedding"
2626

2727

28-
class ToolVectorDBError(SdkError):
28+
class VectorDBError(SdkError):
2929
DEFAULT_MESSAGE = "Error ocurred related to vector DB"
3030

3131

0 commit comments

Comments
 (0)