Commit e23f33f

Merge pull request #762 from parea-ai/fix-json-templated-messages

fix: capture inputs to auto-traced LLM calls with json fields

2 parents f000e02 + 3399c94 commit e23f33f

File tree

4 files changed: +38 -3 lines changed
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import os
+
+from dotenv import load_dotenv
+from openai import OpenAI
+
+from parea import Parea
+
+load_dotenv()
+
+client = OpenAI()
+p = Parea(api_key=os.getenv("PAREA_API_KEY"))
+p.wrap_openai_client(client)
+
+response = client.chat.completions.create(
+    model="gpt-4",
+    messages=[
+        {"role": "user", "content": "Make up {{number}} people. Some {abc}: def"},
+    ],
+    template_inputs={"number": "three"},
+)
+print(response.choices[0].message.content)
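
Context for this example (not part of the diff): before this fix the wrapper filled template variables with plain str.format, so a prompt that also contains literal JSON-style braces like {abc} raised a KeyError during auto-tracing. A minimal sketch of that failure mode:

# Pre-fix behaviour: str.format treats every single-brace field as a
# template variable, so JSON-style braces in the prompt break formatting.
content = "Make up {{number}} people. Some {abc}: def"
try:
    content.format(number="three")
except KeyError as e:
    print(f"KeyError: {e}")  # prints: KeyError: 'abc'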

parea/wrapper/utils.py

Lines changed: 14 additions & 0 deletions

@@ -1,6 +1,7 @@
 from typing import Callable, Dict, List, Optional, Union

 import json
+import re
 import sys
 from functools import lru_cache, wraps

@@ -338,3 +339,16 @@ def convert_openai_raw_stream_to_log(content: list, tools: dict, data: dict, tra

 def convert_openai_raw_to_log(r: dict, data: dict):
     log_in_thread(_process_response, {"response": ChatCompletion(**r), "model_inputs": data, "trace_id": get_current_trace_id()})
+
+
+def safe_format_template_to_prompt(_template: str, **kwargs) -> str:
+    """Replaces langchain.prompts.PromptTemplate.format in a safe manner.
+
+    Only variables '{{...}}' will be replaced, not any in '{...}'
+    """
+
+    def replace(match):
+        var_name = match.group(1)
+        return str(kwargs.get(var_name, match.group(0)))
+
+    return re.sub(r"{{(\w+)}}", replace, _template)
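
For illustration (not part of the diff), a quick check of how the new helper behaves: only {{...}} placeholders are substituted, while single-brace text such as JSON fields passes through untouched, and unknown placeholders are left in place rather than raising.

# Illustrative usage of safe_format_template_to_prompt as added above.
from parea.wrapper.utils import safe_format_template_to_prompt

prompt = 'Return {{number}} people as JSON like {"name": "..."} and keep {abc} as-is.'
print(safe_format_template_to_prompt(prompt, number="three"))
# -> Return three people as JSON like {"name": "..."} and keep {abc} as-is.

print(safe_format_template_to_prompt("Hello {{missing}}", number="three"))
# -> Hello {{missing}}   (unknown variables are kept verbatim)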

parea/wrapper/wrapper.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 from parea.helpers import timezone_aware_now
 from parea.schemas.models import TraceLog, UpdateTraceScenario
 from parea.utils.trace_utils import call_eval_funcs_then_log, fill_trace_data, trace_context, trace_data
-from parea.wrapper.utils import skip_decorator_if_func_in_stack
+from parea.wrapper.utils import safe_format_template_to_prompt, skip_decorator_if_func_in_stack

 logger = logging.getLogger()

@@ -82,7 +82,7 @@ def _init_trace(self, kwargs) -> Tuple[str, datetime, contextvars.Token]:
         if template_inputs := kwargs.pop("template_inputs", None):
             for m in kwargs["messages"] or []:
                 if isinstance(m, dict) and "content" in m:
-                    m["content"] = m["content"].format(**template_inputs)
+                    m["content"] = safe_format_template_to_prompt(m["content"], **template_inputs)

         if TURN_OFF_PAREA_LOGGING:
             return trace_id, start_time, token
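
For reference, a standalone sketch (not part of the diff) of the substitution step the updated _init_trace now performs on the call kwargs, using the helper from parea/wrapper/utils.py:

# Mirrors the updated hunk above: template_inputs is popped from the kwargs
# and applied to each message's content with the brace-safe helper.
from parea.wrapper.utils import safe_format_template_to_prompt

kwargs = {
    "messages": [{"role": "user", "content": "Make up {{number}} people. Some {abc}: def"}],
    "template_inputs": {"number": "three"},
}

if template_inputs := kwargs.pop("template_inputs", None):
    for m in kwargs["messages"] or []:
        if isinstance(m, dict) and "content" in m:
            m["content"] = safe_format_template_to_prompt(m["content"], **template_inputs)

print(kwargs["messages"][0]["content"])
# -> Make up three people. Some {abc}: def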

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.131"
+version = "0.2.132"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]
