|
14 | 14 | from unstract.sdk.adapters import ToolAdapter |
15 | 15 | from unstract.sdk.constants import LogLevel, ToolEnv |
16 | 16 | from unstract.sdk.exceptions import LLMError, RateLimitError, SdkError |
| 17 | +from unstract.sdk.helper import SdkHelper |
17 | 18 | from unstract.sdk.tool.base import BaseTool |
18 | 19 | from unstract.sdk.utils.callback_manager import CallbackManager |
19 | 20 |
|
@@ -54,12 +55,16 @@ def _initialise(self): |
54 | 55 | if self._adapter_instance_id: |
55 | 56 | self._llm_instance = self._get_llm(self._adapter_instance_id) |
56 | 57 | self._usage_kwargs["adapter_instance_id"] = self._adapter_instance_id |
57 | | - platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY) |
58 | | - CallbackManager.set_callback( |
59 | | - platform_api_key=platform_api_key, |
60 | | - model=self._llm_instance, |
61 | | - kwargs=self._usage_kwargs, |
62 | | - ) |
| 58 | + |
| 59 | + if not SdkHelper.is_public_adapter( |
| 60 | + adapter_id=self._adapter_instance_id |
| 61 | + ): |
| 62 | + platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY) |
| 63 | + CallbackManager.set_callback( |
| 64 | + platform_api_key=platform_api_key, |
| 65 | + model=self._llm_instance, |
| 66 | + kwargs=self._usage_kwargs, |
| 67 | + ) |
63 | 68 |
|
64 | 69 | def complete( |
65 | 70 | self, |
@@ -94,9 +99,11 @@ def _get_llm(self, adapter_instance_id: str) -> LlamaIndexLLM: |
94 | 99 | try: |
95 | 100 | if not self._adapter_instance_id: |
96 | 101 | raise LLMError("Adapter instance ID not set. " "Initialisation failed") |
| 102 | + |
97 | 103 | llm_config_data = ToolAdapter.get_adapter_config( |
98 | 104 | self._tool, self._adapter_instance_id |
99 | 105 | ) |
| 106 | + |
100 | 107 | llm_adapter_id = llm_config_data.get(Common.ADAPTER_ID) |
101 | 108 | if llm_adapter_id not in self.llm_adapters: |
102 | 109 | raise SdkError(f"LLM adapter not supported : " f"{llm_adapter_id}") |
|
0 commit comments