Skip to content

Commit a6e4d5d

Browse files
committed
Merge branch 'main' into PAI-449-improve-sdk-trace-log-system-architecture-pys
2 parents c75d7b6 + 9ed9c19 commit a6e4d5d

17 files changed

+299
-59
lines changed

.github/workflows/build.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
strategy:
99
matrix:
10-
python-version: [ "3.9" ]
10+
python-version: [ "3.8" ]
1111

1212
steps:
1313
- uses: actions/[email protected]
@@ -33,9 +33,9 @@ jobs:
3333
run: |
3434
make check-codestyle
3535
36-
# - name: Run tests
37-
# run: |
38-
# make test
36+
- name: Run tests
37+
run: |
38+
make test
3939
4040
# - name: Run safety checks
4141
# run: |

parea/__init__.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
run the following command: ```bash pip install parea ```.
1010
"""
1111
import sys
12-
from importlib import metadata as importlib_metadata
1312

13+
from parea.api_client import get_version
1414
from parea.cache import InMemoryCache
1515
from parea.client import Parea
1616
from parea.experiment.cli import experiment as _experiment_cli
@@ -22,14 +22,6 @@
2222
from parea.wrapper.openai_raw_api_tracer import aprocess_stream_and_yield, process_stream_and_yield
2323
from parea.wrapper.utils import convert_openai_raw_to_log
2424

25-
26-
def get_version() -> str:
27-
try:
28-
return importlib_metadata.version(__name__)
29-
except importlib_metadata.PackageNotFoundError: # pragma: no cover
30-
return "unknown"
31-
32-
3325
version: str = get_version()
3426

3527

parea/api_client.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
from typing import Any, AsyncIterable, Callable, Dict, Optional
1+
from typing import Any, AsyncIterable, Callable, Dict, List, Optional
22

33
import asyncio
44
import json
55
import os
66
import time
77
from functools import wraps
8+
from importlib import metadata as importlib_metadata
89

910
import httpx
1011
from dotenv import load_dotenv
1112

1213
load_dotenv()
1314

14-
MAX_RETRIES = 7
15+
MAX_RETRIES = 8
1516
BACKOFF_FACTOR = 0.5
1617

1718

@@ -60,6 +61,7 @@ class HTTPClient:
6061
_instance = None
6162
base_url = os.getenv("PAREA_BASE_URL", "https://parea-ai-backend-us-9ac16cdbc7a7b006.onporter.run/api/parea/v1")
6263
api_key = None
64+
integrations: List[str] = []
6365

6466
def __new__(cls, *args, **kwargs):
6567
if cls._instance is None:
@@ -71,6 +73,20 @@ def __new__(cls, *args, **kwargs):
7173
def set_api_key(self, api_key: str):
7274
self.api_key = api_key
7375

76+
def add_integration(self, integration: str):
    """Register an integration name, ignoring names that are already present.

    Args:
        integration: Name to append to ``self.integrations``.
    """
    already_registered = integration in self.integrations
    if already_registered:
        return
    self.integrations.append(integration)
79+
80+
def _get_headers(self, api_key: Optional[str] = None) -> Dict[str, str]:
    """Build the HTTP headers attached to an outgoing request.

    Args:
        api_key: Fallback API key. It is only used when no key has been set on
            the client itself; ``self.api_key`` takes precedence.

    Returns:
        A header mapping containing ``x-sdk-version`` and ``x-sdk-language``,
        plus ``x-api-key`` when a key is available and ``x-sdk-integrations``
        (comma-separated) when at least one integration has been registered.
    """
    headers: Dict[str, str] = {}
    # Header values have to be strings. Emitting ``None`` here would make the
    # HTTP client fail while encoding the headers, so the key is left out
    # entirely when neither the client nor the caller supplied one.
    resolved_key = self.api_key or api_key
    if resolved_key:
        headers["x-api-key"] = resolved_key
    headers["x-sdk-version"] = get_version()
    headers["x-sdk-language"] = "python"
    if self.integrations:
        headers["x-sdk-integrations"] = ",".join(self.integrations)
    return headers
89+
7490
@retry_on_502
7591
def request(
7692
self,
@@ -83,7 +99,7 @@ def request(
8399
"""
84100
Makes an HTTP request to the specified endpoint.
85101
"""
86-
headers = {"x-api-key": self.api_key} if self.api_key else api_key
102+
headers = self._get_headers(api_key=api_key)
87103
try:
88104
response = self.sync_client.request(method, endpoint, json=data, headers=headers, params=params)
89105
response.raise_for_status()
@@ -106,7 +122,7 @@ async def request_async(
106122
"""
107123
Makes an asynchronous HTTP request to the specified endpoint.
108124
"""
109-
headers = {"x-api-key": self.api_key} if self.api_key else api_key
125+
headers = self._get_headers(api_key=api_key)
110126
try:
111127
response = await self.async_client.request(method, endpoint, json=data, headers=headers, params=params)
112128
response.raise_for_status()
@@ -128,7 +144,7 @@ def stream_request(
128144
"""
129145
Makes a streaming HTTP request to the specified endpoint, yielding chunks of data.
130146
"""
131-
headers = {"x-api-key": self.api_key} if self.api_key else api_key
147+
headers = self._get_headers(api_key=api_key)
132148
try:
133149
with self.sync_client.stream(method, endpoint, json=data, headers=headers, params=params, timeout=None) as response:
134150
response.raise_for_status()
@@ -151,7 +167,7 @@ async def stream_request_async(
151167
"""
152168
Makes an asynchronous streaming HTTP request to the specified endpoint, yielding chunks of data.
153169
"""
154-
headers = {"x-api-key": self.api_key} if self.api_key else api_key
170+
headers = self._get_headers(api_key=api_key)
155171
try:
156172
async with self.async_client.stream(method, endpoint, json=data, headers=headers, params=params, timeout=None) as response:
157173
response.raise_for_status()
@@ -200,3 +216,10 @@ def parse_event_data(byte_data):
200216
except Exception as e:
201217
print(f"Error parsing event data: {e}")
202218
return None
219+
220+
221+
def get_version() -> str:
    """Return the installed version of the ``parea-ai`` distribution.

    Returns:
        The version string reported by the package metadata, or ``"unknown"``
        when no ``parea-ai`` distribution is installed.
    """
    try:
        installed_version = importlib_metadata.version("parea-ai")
    except importlib_metadata.PackageNotFoundError:  # pragma: no cover
        installed_version = "unknown"
    return installed_version

parea/client.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from parea.cache.cache import Cache
1515
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
1616
from parea.experiment.datasets import create_test_cases, create_test_collection
17-
from parea.helpers import gen_trace_id, serialize_metadata_values, structure_trace_log_from_api
17+
from parea.helpers import gen_trace_id, serialize_metadata_values, structure_trace_log_from_api, structure_trace_logs_from_api
1818
from parea.parea_logger import parea_logger
1919
from parea.schemas.models import (
2020
Completion,
@@ -25,11 +25,14 @@
2525
CreateTestCases,
2626
ExperimentSchema,
2727
ExperimentStatsSchema,
28+
ExperimentWithPinnedStatsSchema,
2829
FeedbackRequest,
2930
FinishExperimentRequestSchema,
31+
ListExperimentUUIDsFilters,
3032
ProjectSchema,
3133
TestCaseCollection,
3234
TraceLog,
35+
TraceLogFilters,
3336
UseDeployedPrompt,
3437
UseDeployedPromptResponse,
3538
)
@@ -50,6 +53,8 @@
5053
CREATE_COLLECTION_ENDPOINT = "/collection"
5154
ADD_TEST_CASES_ENDPOINT = "/testcases"
5255
GET_TRACE_LOG_ENDPOINT = "/trace_log/{trace_id}"
56+
LIST_EXPERIMENTS_ENDPOINT = "/experiments"
57+
GET_EXPERIMENT_LOGS_ENDPOINT = "/experiment/{experiment_uuid}/trace_logs"
5358

5459

5560
@define
@@ -71,19 +76,25 @@ def __attrs_post_init__(self):
7176
parea_logger.set_client(self._client)
7277
parea_logger.set_project_uuid(self.project_uuid)
7378

74-
def wrap_openai_client(self, client: "OpenAI") -> None:
79+
def wrap_openai_client(self, client: "OpenAI", integration: Optional[str] = None) -> None:
    """Only necessary for instance client with OpenAI version >= 1.0.0

    Initialises ``OpenAIWrapper`` and ``BetaWrappers`` for ``client`` and, when
    ``integration`` is a non-empty string, registers that name on the
    underlying HTTP client.
    """
    from parea.wrapper import OpenAIWrapper
    from parea.wrapper.openai_beta_wrapper import BetaWrappers

    openai_wrapper = OpenAIWrapper()
    openai_wrapper.init(log=logger_all_possible, cache=self.cache, module_client=client)
    beta_wrappers = BetaWrappers(client)
    beta_wrappers.init()

    if not integration:
        return
    self._client.add_integration(integration)
89+
90+
def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str] = None) -> None:
    """Initialise ``AnthropicWrapper`` for ``client``.

    When ``integration`` is a non-empty string, that name is also registered
    on the underlying HTTP client.
    """
    from parea.wrapper.anthropic.anthropic import AnthropicWrapper

    anthropic_wrapper = AnthropicWrapper()
    anthropic_wrapper.init(log=logger_all_possible, cache=self.cache, client=client)

    if not integration:
        return
    self._client.add_integration(integration)
97+
8798
def auto_trace_openai_clients(self) -> None:
8899
import openai
89100

@@ -92,6 +103,10 @@ def auto_trace_openai_clients(self) -> None:
92103
openai.AzureOpenAI = patch_openai_client_classes(openai.AzureOpenAI, self)
93104
openai.AsyncAzureOpenAI = patch_openai_client_classes(openai.AsyncAzureOpenAI, self)
94105

106+
def integrate_with_sglang(self):
    """Run ``auto_trace_openai_clients`` and register the ``"sglang"`` integration."""
    self.auto_trace_openai_clients()
    http_client = self._client
    http_client.add_integration("sglang")
109+
95110
def _add_project_uuid_to_data(self, data) -> dict:
96111
data_dict = asdict(data)
97112
data_dict["project_uuid"] = self._project.uuid
@@ -346,6 +361,22 @@ async def aget_trace_log(self, trace_id: str) -> TraceLog:
346361
response = await self._client.request_async("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
347362
return structure_trace_log_from_api(response.json())
348363

364+
def list_experiments(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = None) -> List[ExperimentWithPinnedStatsSchema]:
    """List experiments matching the given filter conditions.

    Args:
        filter_conditions: Filters to send with the request. ``None`` (the
            default) is treated as a freshly constructed
            ``ListExperimentUUIDsFilters()``.

    Returns:
        The response JSON structured as a list of
        ``ExperimentWithPinnedStatsSchema``.
    """
    # Build the default per call: a default evaluated in the signature is
    # created once at definition time, and ``asdict(None)`` would fail.
    if filter_conditions is None:
        filter_conditions = ListExperimentUUIDsFilters()
    response = self._client.request("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
    return structure(response.json(), List[ExperimentWithPinnedStatsSchema])
367+
368+
async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = None) -> List[ExperimentWithPinnedStatsSchema]:
    """Asynchronously list experiments matching the given filter conditions.

    Args:
        filter_conditions: Filters to send with the request. ``None`` (the
            default) is treated as a freshly constructed
            ``ListExperimentUUIDsFilters()``.

    Returns:
        The response JSON structured as a list of
        ``ExperimentWithPinnedStatsSchema``.
    """
    # Build the default per call: a default evaluated in the signature is
    # created once at definition time, and ``asdict(None)`` would fail.
    if filter_conditions is None:
        filter_conditions = ListExperimentUUIDsFilters()
    response = await self._client.request_async("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
    return structure(response.json(), List[ExperimentWithPinnedStatsSchema])
371+
372+
def get_experiment_trace_logs(self, experiment_uuid: str, filters: Optional[TraceLogFilters] = None) -> List[TraceLog]:
    """Fetch the trace logs of one experiment.

    Args:
        experiment_uuid: UUID substituted into the experiment trace-log endpoint.
        filters: Filters to send with the request. ``None`` (the default) is
            treated as a freshly constructed ``TraceLogFilters()``.

    Returns:
        The response JSON converted by ``structure_trace_logs_from_api``.
    """
    # Create the default filters per call instead of sharing one instance that
    # was built when the method was defined.
    if filters is None:
        filters = TraceLogFilters()
    endpoint = GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid)
    response = self._client.request("POST", endpoint, data=asdict(filters))
    return structure_trace_logs_from_api(response.json())
375+
376+
async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: Optional[TraceLogFilters] = None) -> List[TraceLog]:
    """Asynchronously fetch the trace logs of one experiment.

    Args:
        experiment_uuid: UUID substituted into the experiment trace-log endpoint.
        filters: Filters to send with the request. ``None`` (the default) is
            treated as a freshly constructed ``TraceLogFilters()``.

    Returns:
        The response JSON converted by ``structure_trace_logs_from_api``.
    """
    # Create the default filters per call instead of sharing one instance that
    # was built when the method was defined.
    if filters is None:
        filters = TraceLogFilters()
    endpoint = GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid)
    response = await self._client.request_async("POST", endpoint, data=asdict(filters))
    return structure_trace_logs_from_api(response.json())
379+
349380

350381
_initialized_parea_wrapper = False
351382

parea/cookbook/list_experiments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import os
2+
3+
from dotenv import load_dotenv
4+
5+
from parea import Parea
6+
from parea.schemas import ListExperimentUUIDsFilters
7+
8+
# Load PAREA_API_KEY (and any other settings) from a local .env file.
load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))

# List the experiments selected by the name filter "Greeting".
experiments = p.list_experiments(ListExperimentUUIDsFilters(experiment_name_filter="Greeting"))
print(f"Num. experiments: {len(experiments)}")

# Indexing an empty result would raise IndexError, so the trace logs are only
# fetched when at least one experiment matched the filter.
if experiments:
    trace_logs = p.get_experiment_trace_logs(experiments[0].uuid)
    print(f"Num. trace logs: {len(trace_logs)}")
    if trace_logs:
        print(f"Trace log: {trace_logs[0]}")
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
import os
3+
4+
from anthropic import Anthropic
5+
from dotenv import load_dotenv
6+
from openai import OpenAI
7+
8+
from parea import Parea, trace, trace_insert
9+
from parea.schemas import TraceLogImage
10+
11+
load_dotenv()
12+
13+
14+
oai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
15+
a_client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
16+
17+
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
18+
p.wrap_openai_client(oai_client)
19+
p.wrap_anthropic_client(a_client)
20+
21+
22+
@trace
def image_maker(query: str) -> str:
    """Generate an image for ``query`` with the ``dall-e-3`` model.

    The image URL, together with a JSON caption holding the original and the
    revised prompt, is passed to ``trace_insert`` before the URL is returned.
    """
    generation = oai_client.images.generate(prompt=query, model="dall-e-3")
    first_image = generation.data[0]
    image_url = first_image.url
    caption_json = json.dumps({"original_prompt": query, "revised_prompt": first_image.revised_prompt})
    logged_image = TraceLogImage(url=image_url, caption=caption_json)
    trace_insert({"images": [logged_image]})
    return image_url
29+
30+
31+
from typing import Optional
32+
33+
import base64
34+
35+
import requests
36+
37+
38+
@trace
def ask_vision(image_url: str) -> Optional[str]:
    """Download an image and ask the Anthropic model what it shows.

    Args:
        image_url: URL of the image to download.

    Returns:
        The text of the first content block in the model response.

    Raises:
        requests.HTTPError: If the image download returns an error status.
        requests.Timeout: If the image host does not respond in time.
    """
    # Fail fast on a bad download rather than base64-encoding an error page as
    # if it were a PNG, and never wait indefinitely on an unresponsive host.
    download = requests.get(image_url, timeout=30)
    download.raise_for_status()
    base64_image = base64.b64encode(download.content).decode("utf-8")

    response = a_client.messages.create(
        model="claude-3-haiku-20240307",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": base64_image,
                        },
                    },
                    {"type": "text", "text": "What’s in this image?"},
                ],
            }
        ],
        max_tokens=300,
    )
    return response.content[0].text
64+
65+
66+
@trace
def main(query: str) -> Optional[str]:
    """Generate an image for ``query`` and return the vision model's description.

    The return annotation mirrors ``ask_vision``, whose result is returned
    unchanged and is declared as ``Optional[str]``.
    """
    image_url = image_maker(query)
    return ask_vision(image_url)


if __name__ == "__main__":
    result = main("A dog sitting comfortably on a chair")
    print(result)

parea/cookbook/tracing_with_images_open_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def image_maker(query: str) -> str:
3030
@trace
3131
def ask_vision(image_url: str) -> Optional[str]:
3232
response = client.chat.completions.create(
33-
model="gpt-4-vision-preview",
33+
model="gpt-4-turbo",
3434
messages=[
3535
{
3636
"role": "user",

parea/evals/rag/context_query_relevancy.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Callable, List, Optional
22

3-
from parea.evals.utils import call_openai, sent_tokenize
3+
from parea.evals.utils import call_openai, get_context, sent_tokenize
44
from parea.schemas.log import Log
55

66

@@ -27,13 +27,7 @@ def context_query_relevancy_factory(
2727
def context_query_relevancy(log: Log) -> float:
2828
"""Quantifies how much the retrieved context relates to the query."""
2929
question = log.inputs[question_field]
30-
if context_fields:
31-
context = "\n".join(log.inputs[context_field] for context_field in context_fields)
32-
else:
33-
if isinstance(log.output, list):
34-
context = "\n".join(log.output)
35-
else:
36-
context = str(log.output)
30+
context = get_context(log, context_fields)
3731

3832
extracted_sentences = call_openai(
3933
model=model,

parea/evals/rag/context_ranking_listwise.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Callable, List, Optional
22

3-
from parea.evals.utils import call_openai, ndcg
3+
from parea.evals.utils import call_openai, get_context, ndcg
44
from parea.schemas.log import Log
55

66

@@ -99,13 +99,7 @@ def progressive_reranking(query: str, contexts: List[str]) -> List[int]:
9999
def context_ranking(log: Log) -> float:
100100
"""Quantifies if the retrieved context is ranked by their relevancy by re-ranking the contexts."""
101101
question = log.inputs[question_field]
102-
if context_fields:
103-
contexts = [log.inputs[context_field] for context_field in context_fields]
104-
else:
105-
if isinstance(log.output, list):
106-
contexts = log.output
107-
else:
108-
contexts = [str(log.output)]
102+
contexts = get_context(log, context_fields, True)
109103

110104
reranked_indices = progressive_reranking(question, contexts)
111105

0 commit comments

Comments
 (0)