Skip to content

Commit 623807c

Browse files
FEAT: Add Support for Public Indexing and Prompt Run Functionality (#74)
* Implemented SPS support for index method * Changes to support public calls * Changes to support public calls * Fixed sonar issues * Code optimization * Reverted index.py file to its previous state * Fixed pre-commit issues * Code quality improvements and minor bug fixes * Fixed pre-commit issues * Fixed pre-commit issues * Added log message and optimized code --------- Co-authored-by: Gayathri <[email protected]>
1 parent 0c624b3 commit 623807c

File tree

8 files changed

+117
-22
lines changed

8 files changed

+117
-22
lines changed

src/unstract/sdk/adapters.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
from typing import Any, Optional
23

34
import requests
45

56
from unstract.sdk.constants import AdapterKeys, LogLevel, ToolEnv
7+
from unstract.sdk.helper import SdkHelper
68
from unstract.sdk.platform import PlatformBase
79
from unstract.sdk.tool.base import BaseTool
810

@@ -88,6 +90,11 @@ def get_adapter_config(
8890
) -> Optional[dict[str, Any]]:
8991
"""Get adapter spec by the help of unstract DB tool.
9092
93+
This method first checks if the adapter_instance_id matches
94+
any of the public adapter keys. If it matches, the configuration
95+
is fetched from environment variables. Otherwise, it connects to the
96+
platform service to retrieve the configuration.
97+
9198
Args:
9299
adapter_instance_id (str): ID of the adapter instance
93100
tool (AbstractTool): Instance of AbstractTool
@@ -97,6 +104,15 @@ def get_adapter_config(
97104
Returns:
98105
Any: engine
99106
"""
107+
# Check if the adapter ID matches any public adapter keys
108+
if SdkHelper.is_public_adapter(
109+
adapter_id=adapter_instance_id
110+
):
111+
adapter_metadata_config = tool.get_env_or_die(
112+
adapter_instance_id
113+
)
114+
adapter_metadata = json.loads(adapter_metadata_config)
115+
return adapter_metadata
100116
platform_host = tool.get_env_or_die(ToolEnv.PLATFORM_HOST)
101117
platform_port = tool.get_env_or_die(ToolEnv.PLATFORM_PORT)
102118

src/unstract/sdk/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,10 @@ class ToolSettingsKey:
151151
RUN_ID = "run_id"
152152
WORKFLOW_ID = "workflow_id"
153153
EXECUTION_ID = "execution_id"
154+
155+
156+
# Identifiers of the "public" adapters. SdkHelper.is_public_adapter compares an
# adapter instance ID against these values, and ToolAdapter.get_adapter_config
# then reads the adapter's JSON configuration from the environment variable of
# the same name instead of calling the platform service.
class PublicAdapterKeys:
157+
PUBLIC_LLM_CONFIG = "PUBLIC_LLM_CONFIG"
158+
PUBLIC_EMBEDDING_CONFIG = "PUBLIC_EMBEDDING_CONFIG"
159+
PUBLIC_VECTOR_DB_CONFIG = "PUBLIC_VECTOR_DB_CONFIG"
160+
PUBLIC_X2TEXT_CONFIG = "PUBLIC_X2TEXT_CONFIG"

src/unstract/sdk/embedding.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from unstract.sdk.adapters import ToolAdapter
1010
from unstract.sdk.constants import LogLevel, ToolEnv
1111
from unstract.sdk.exceptions import EmbeddingError, SdkError
12+
from unstract.sdk.helper import SdkHelper
1213
from unstract.sdk.tool.base import BaseTool
1314
from unstract.sdk.utils.callback_manager import CallbackManager
1415

@@ -36,12 +37,16 @@ def _initialise(self):
3637
self._embedding_instance = self._get_embedding()
3738
self._length: int = self._get_embedding_length()
3839
self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
39-
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
40-
CallbackManager.set_callback(
41-
platform_api_key=platform_api_key,
42-
model=self._embedding_instance,
43-
kwargs=self._usage_kwargs,
44-
)
40+
41+
if not SdkHelper.is_public_adapter(
42+
adapter_id=self._adapter_instance_id
43+
):
44+
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
45+
CallbackManager.set_callback(
46+
platform_api_key=platform_api_key,
47+
model=self._embedding_instance,
48+
kwargs=self._usage_kwargs,
49+
)
4550

4651
def _get_embedding(self) -> BaseEmbedding:
4752
"""Gets an instance of LlamaIndex's embedding object.
@@ -57,6 +62,7 @@ def _get_embedding(self) -> BaseEmbedding:
5762
raise EmbeddingError(
5863
"Adapter instance ID not set. " "Initialisation failed"
5964
)
65+
6066
embedding_config_data = ToolAdapter.get_adapter_config(
6167
self._tool, self._adapter_instance_id
6268
)

src/unstract/sdk/helper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import logging
2+
3+
from unstract.sdk.constants import PublicAdapterKeys
4+
5+
logger = logging.getLogger(__name__)
6+
17
class SdkHelper:
28
def __init__(self) -> None:
39
pass
@@ -16,3 +22,29 @@ def get_platform_base_url(platform_host: str, platform_port: str) -> str:
1622
if platform_host[-1] == "/":
1723
return f"{platform_host[:-1]}:{platform_port}"
1824
return f"{platform_host}:{platform_port}"
25+
26+
@staticmethod
def is_public_adapter(adapter_id: str) -> bool:
    """Check if the given adapter_id is one of the public adapter keys.

    Only the public (non-underscore) attributes of ``PublicAdapterKeys``
    are considered. Implicit class attributes such as ``__doc__`` (which
    is ``None``) and ``__module__`` are skipped, so a missing adapter ID
    or an unrelated string is never reported as public.

    Args:
        adapter_id (str): The ID of the adapter to check.

    Returns:
        bool: True if the adapter_id matches any public adapter key,
            False otherwise (including when adapter_id is not a string
            or the check itself fails).
    """
    # Only a string can name a public adapter. Without this guard, None
    # would compare equal to the class's implicit ``__doc__`` attribute.
    if not isinstance(adapter_id, str):
        return False
    try:
        return any(
            getattr(PublicAdapterKeys, attr) == adapter_id
            for attr in dir(PublicAdapterKeys)
            # Skip dunder/private attributes; they are not adapter keys.
            if not attr.startswith("_")
        )
    except Exception as e:
        # Best-effort check: treat any failure as "not public" so callers
        # fall back to the regular platform-service path.
        logger.warning(
            "Unable to determine if adapter_id: %s is public or not: %s",
            adapter_id,
            e,
        )
        return False

src/unstract/sdk/llm.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from unstract.sdk.adapters import ToolAdapter
1515
from unstract.sdk.constants import LogLevel, ToolEnv
1616
from unstract.sdk.exceptions import LLMError, RateLimitError, SdkError
17+
from unstract.sdk.helper import SdkHelper
1718
from unstract.sdk.tool.base import BaseTool
1819
from unstract.sdk.utils.callback_manager import CallbackManager
1920

@@ -54,12 +55,16 @@ def _initialise(self):
5455
if self._adapter_instance_id:
5556
self._llm_instance = self._get_llm(self._adapter_instance_id)
5657
self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id
57-
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
58-
CallbackManager.set_callback(
59-
platform_api_key=platform_api_key,
60-
model=self._llm_instance,
61-
kwargs=self._usage_kwargs,
62-
)
58+
59+
if not SdkHelper.is_public_adapter(
60+
adapter_id=self._adapter_instance_id
61+
):
62+
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
63+
CallbackManager.set_callback(
64+
platform_api_key=platform_api_key,
65+
model=self._llm_instance,
66+
kwargs=self._usage_kwargs,
67+
)
6368

6469
def complete(
6570
self,
@@ -94,9 +99,11 @@ def _get_llm(self, adapter_instance_id: str) -> LlamaIndexLLM:
9499
try:
95100
if not self._adapter_instance_id:
96101
raise LLMError("Adapter instance ID not set. " "Initialisation failed")
102+
97103
llm_config_data = ToolAdapter.get_adapter_config(
98104
self._tool, self._adapter_instance_id
99105
)
106+
100107
llm_adapter_id = llm_config_data.get(Common.ADAPTER_ID)
101108
if llm_adapter_id not in self.llm_adapters:
102109
raise SdkError(f"LLM adapter not supported : " f"{llm_adapter_id}")

src/unstract/sdk/prompt.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
tool: BaseTool,
2020
prompt_host: str,
2121
prompt_port: str,
22+
is_public_call: bool = False,
2223
) -> None:
2324
"""
2425
Args:
@@ -28,13 +29,18 @@ def __init__(
2829
"""
2930
self.tool = tool
3031
self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port)
31-
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
32+
self.is_public_call = is_public_call
33+
if not is_public_call:
34+
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
3235

3336
def answer_prompt(
3437
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
3538
) -> dict[str, Any]:
39+
url_path = "answer-prompt"
40+
if self.is_public_call:
41+
url_path = "answer-prompt-public"
3642
return self._post_call(
37-
url_path="answer-prompt",
43+
url_path=url_path,
3844
payload=payload,
3945
params=params,
4046
)
@@ -85,14 +91,16 @@ def _post_call(
8591
"structure_output": "",
8692
}
8793
url: str = f"{self.base_url}/{url_path}"
88-
headers: dict[str, str] = {"Authorization": f"Bearer {self.bearer_token}"}
94+
headers: dict[str, str] = {}
95+
if not self.is_public_call:
96+
headers = {"Authorization": f"Bearer {self.bearer_token}"}
8997
response: Response = Response()
9098
try:
9199
response = requests.post(
92100
url=url,
93101
json=payload,
94-
headers=headers,
95102
params=params,
103+
headers=headers
96104
)
97105
response.raise_for_status()
98106
result["status"] = "OK"

src/unstract/sdk/vector_db.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from unstract.sdk.constants import LogLevel, ToolEnv
2020
from unstract.sdk.embedding import Embedding
2121
from unstract.sdk.exceptions import SdkError, VectorDBError
22+
from unstract.sdk.helper import SdkHelper
2223
from unstract.sdk.platform import PlatformHelper
2324
from unstract.sdk.tool.base import BaseTool
2425

@@ -83,9 +84,11 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]:
8384
raise VectorDBError(
8485
"Adapter instance ID not set. Initialisation failed"
8586
)
87+
8688
vector_db_config = ToolAdapter.get_adapter_config(
8789
self._tool, self._adapter_instance_id
8890
)
91+
8992
vector_db_adapter_id = vector_db_config.get(Common.ADAPTER_ID)
9093
if vector_db_adapter_id not in self.vector_db_adapters:
9194
raise SdkError(
@@ -96,10 +99,13 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]:
9699
Common.METADATA
97100
][Common.ADAPTER]
98101
vector_db_metadata = vector_db_config.get(Common.ADAPTER_METADATA)
99-
org = self._get_org_id()
100102
# Adding the collection prefix and embedding type
101103
# to the metadata
102-
vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org
104+
105+
if not SdkHelper.is_public_adapter(adapter_id=self._adapter_instance_id):
106+
org = self._get_org_id()
107+
vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org
108+
103109
vector_db_metadata[
104110
VectorDbConstants.EMBEDDING_DIMENSION
105111
] = self._embedding_dimension

src/unstract/sdk/x2txt.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@
1111
from unstract.sdk.adapters import ToolAdapter
1212
from unstract.sdk.constants import LogLevel
1313
from unstract.sdk.exceptions import X2TextError
14+
from unstract.sdk.helper import SdkHelper
1415
from unstract.sdk.tool.base import BaseTool
1516

1617

1718
class X2Text(metaclass=ABCMeta):
18-
def __init__(self, tool: BaseTool, adapter_instance_id: Optional[str] = None):
19+
def __init__(
20+
self,
21+
tool: BaseTool,
22+
adapter_instance_id: Optional[str] = None
23+
):
1924
self._tool = tool
2025
self._x2text_adapters = adapters
2126
self._adapter_instance_id = adapter_instance_id
@@ -32,9 +37,11 @@ def _get_x2text(self) -> X2TextAdapter:
3237
raise X2TextError(
3338
"Adapter instance ID not set. " "Initialisation failed"
3439
)
40+
3541
x2text_config = ToolAdapter.get_adapter_config(
3642
self._tool, self._adapter_instance_id
3743
)
44+
3845
x2text_adapter_id = x2text_config.get(Common.ADAPTER_ID)
3946
if x2text_adapter_id in self._x2text_adapters:
4047
x2text_adapter = self._x2text_adapters[x2text_adapter_id][
@@ -48,9 +55,15 @@ def _get_x2text(self) -> X2TextAdapter:
4855
x2text_metadata[
4956
X2TextConstants.X2TEXT_PORT
5057
] = self._tool.get_env_or_die(X2TextConstants.X2TEXT_PORT)
51-
x2text_metadata[
52-
X2TextConstants.PLATFORM_SERVICE_API_KEY
53-
] = self._tool.get_env_or_die(X2TextConstants.PLATFORM_SERVICE_API_KEY)
58+
59+
if not SdkHelper.is_public_adapter(
60+
adapter_id=self._adapter_instance_id
61+
):
62+
x2text_metadata[
63+
X2TextConstants.PLATFORM_SERVICE_API_KEY
64+
] = self._tool.get_env_or_die(
65+
X2TextConstants.PLATFORM_SERVICE_API_KEY
66+
)
5467

5568
self._x2text_instance = x2text_adapter(x2text_metadata)
5669

0 commit comments

Comments
 (0)