
Commit 6488769

single trace log prep for ch
2 parents f0e820c + e23f33f commit 6488769

4 files changed: +46 -6 lines changed

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import os
+
+from dotenv import load_dotenv
+from openai import OpenAI
+
+from parea import Parea
+
+load_dotenv()
+
+client = OpenAI()
+p = Parea(api_key=os.getenv("PAREA_API_KEY"))
+p.wrap_openai_client(client)
+
+response = client.chat.completions.create(
+    model="gpt-4",
+    messages=[
+        {"role": "user", "content": "Make up {{number}} people. Some {abc}: def"},
+    ],
+    template_inputs={"number": "three"},
+)
+print(response.choices[0].message.content)
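
With the safe_format_template_to_prompt helper added in parea/wrapper/utils.py below, the wrapped client should fill only the double-braced placeholder before the request is sent, so the model presumably receives "Make up three people. Some {abc}: def" with the single-braced {abc} left untouched; template_inputs={"number": "three"} is then recorded as the trace's inputs.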

parea/wrapper/utils.py

Lines changed: 14 additions & 0 deletions

@@ -1,6 +1,7 @@
 from typing import Callable, Dict, List, Optional, Union

 import json
+import re
 import sys
 from functools import lru_cache, wraps

@@ -336,3 +337,16 @@ def convert_openai_raw_stream_to_log(content: list, tools: dict, data: dict, tra

 def convert_openai_raw_to_log(r: dict, data: dict):
     log_in_thread(_process_response, {"response": ChatCompletion(**r), "model_inputs": data, "trace_id": get_current_trace_id()})
+
+
+def safe_format_template_to_prompt(_template: str, **kwargs) -> str:
+    """Replaces langchain.prompts.PromptTemplate.format in a safe manner.
+
+    Only variables '{{...}}' will be replaced, not any in '{...}'
+    """
+
+    def replace(match):
+        var_name = match.group(1)
+        return str(kwargs.get(var_name, match.group(0)))
+
+    return re.sub(r"{{(\w+)}}", replace, _template)
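
For reference, a minimal standalone sketch of the helper's behavior; the function body mirrors the hunk above, and the sample strings are illustrative only, not part of the commit:

import re


def safe_format_template_to_prompt(_template: str, **kwargs) -> str:
    """Replace only double-braced '{{var}}' placeholders; leave single-braced '{...}' text untouched."""

    def replace(match):
        var_name = match.group(1)
        return str(kwargs.get(var_name, match.group(0)))  # unknown variables stay as '{{var}}'

    return re.sub(r"{{(\w+)}}", replace, _template)


print(safe_format_template_to_prompt("Make up {{number}} people. Some {abc}: def", number="three"))
# -> Make up three people. Some {abc}: def
print(safe_format_template_to_prompt("Hello {{missing}}", number="three"))
# -> Hello {{missing}}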

parea/wrapper/wrapper.py

Lines changed: 10 additions & 5 deletions

@@ -14,7 +14,7 @@
 from parea.helpers import timezone_aware_now
 from parea.schemas.models import TraceLog, UpdateTraceScenario
 from parea.utils.trace_utils import call_eval_funcs_then_log, fill_trace_data, trace_context, trace_data
-from parea.wrapper.utils import skip_decorator_if_func_in_stack
+from parea.wrapper.utils import safe_format_template_to_prompt, skip_decorator_if_func_in_stack

 logger = logging.getLogger()

@@ -72,13 +72,18 @@ def _get_decorator(self, unwrapped_func: Callable, original_func: Callable):
         else:
             return self.sync_decorator(original_func)

-    def _init_trace(self) -> Tuple[str, datetime, contextvars.Token]:
+    def _init_trace(self, kwargs) -> Tuple[str, datetime, contextvars.Token]:
         start_time = timezone_aware_now()
         trace_id = str(uuid4())

         new_trace_context = trace_context.get() + [trace_id]
         token = trace_context.set(new_trace_context)

+        if template_inputs := kwargs.pop("template_inputs", None):
+            for m in kwargs["messages"] or []:
+                if isinstance(m, dict) and "content" in m:
+                    m["content"] = safe_format_template_to_prompt(m["content"], **template_inputs)
+
         if TURN_OFF_PAREA_LOGGING:
             return trace_id, start_time, token
         try:
@@ -93,7 +98,7 @@ def _init_trace(self) -> Tuple[str, datetime, contextvars.Token]:
             metadata=None,
             target=None,
             tags=None,
-            inputs={},
+            inputs=template_inputs,
             experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None),
         )

@@ -109,7 +114,7 @@ def _init_trace(self) -> Tuple[str, datetime, contextvars.Token]:
     def async_decorator(self, orig_func: Callable) -> Callable:
         @functools.wraps(orig_func)
         async def wrapper(*args, **kwargs):
-            trace_id, start_time, context_token = self._init_trace()
+            trace_id, start_time, context_token = self._init_trace(kwargs)
             response = None
             exception = None
             error = None
@@ -141,7 +146,7 @@ async def wrapper(*args, **kwargs):
     def sync_decorator(self, orig_func: Callable) -> Callable:
        @functools.wraps(orig_func)
         def wrapper(*args, **kwargs):
-            trace_id, start_time, context_token = self._init_trace()
+            trace_id, start_time, context_token = self._init_trace(kwargs)
             response = None
             error = None
             cache_hit = False
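
A rough sketch of what the new block in _init_trace does to the call's kwargs before the underlying OpenAI method runs; plain dicts stand in for a real chat.completions call, and the helper is imported from the module changed above:

from parea.wrapper.utils import safe_format_template_to_prompt

kwargs = {
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "Make up {{number}} people. Some {abc}: def"}],
    "template_inputs": {"number": "three"},
}

# Mirrors the new block in _init_trace: pop template_inputs and format each message's content.
if template_inputs := kwargs.pop("template_inputs", None):
    for m in kwargs["messages"] or []:
        if isinstance(m, dict) and "content" in m:
            m["content"] = safe_format_template_to_prompt(m["content"], **template_inputs)

print(kwargs["messages"][0]["content"])  # Make up three people. Some {abc}: def
# The popped template_inputs also becomes the TraceLog's `inputs` (previously an empty dict).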

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.131"
+version = "0.2.132"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]

0 commit comments