Skip to content

Commit 57945d6

Browse files
committed
feat: add instructor tracking
1 parent 6aa10cd commit 57945d6

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

parea/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,25 @@ def __attrs_post_init__(self):
7676
parea_logger.set_client(self._client)
7777
parea_logger.set_project_uuid(self.project_uuid)
7878

79-
def wrap_openai_client(self, client: "OpenAI") -> None:
79+
def wrap_openai_client(self, client: "OpenAI", integration: Optional[str] = None) -> None:
8080
"""Only necessary for instance client with OpenAI version >= 1.0.0"""
8181
from parea.wrapper import OpenAIWrapper
8282
from parea.wrapper.openai_beta_wrapper import BetaWrappers
8383

8484
OpenAIWrapper().init(log=logger_all_possible, cache=self.cache, module_client=client)
8585
BetaWrappers(client).init()
8686

87-
def wrap_anthropic_client(self, client: "Anthropic") -> None:
87+
if integration:
88+
self._client.add_integration(integration)
89+
90+
def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str] = None) -> None:
8891
from parea.wrapper.anthropic.anthropic import AnthropicWrapper
8992

9093
AnthropicWrapper().init(log=logger_all_possible, cache=self.cache, client=client)
9194

95+
if integration:
96+
self._client.add_integration(integration)
97+
9298
def auto_trace_openai_clients(self) -> None:
9399
import openai
94100

parea/parea_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def record_vendor_log(self, data: Dict[str, Any], vendor: TraceIntegrations) ->
6666
data["project_uuid"] = self._project_uuid
6767
if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
6868
data["experiment_uuid"] = experiment_uuid
69-
self._client.add_integration('langchain')
69+
self._client.add_integration("langchain")
7070
self._client.request(
7171
"POST",
7272
VENDOR_LOG_ENDPOINT.format(vendor=vendor.value),
@@ -77,7 +77,7 @@ async def arecord_vendor_log(self, data: Dict[str, Any], vendor: TraceIntegratio
7777
data["project_uuid"] = self._project_uuid
7878
if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
7979
data["experiment_uuid"] = experiment_uuid
80-
self._client.add_integration('langchain')
80+
self._client.add_integration("langchain")
8181
await self._client.request_async(
8282
"POST",
8383
VENDOR_LOG_ENDPOINT.format(vendor=vendor.value),

0 commit comments

Comments
 (0)