Commit a8edb5c

structured outputs support
1 parent d0d1d7f commit a8edb5c

6 files changed: +118 -25 lines changed

cookbook/openai/tracing_with_openai_with_structured_output.py

Lines changed: 75 additions & 6 deletions
@@ -28,11 +28,80 @@ class CalendarEvent(BaseModel):
 )


-if __name__ == "__main__":
-    event = completion.choices[0].message.parsed
-    print(type(event))
-    print(event)
+def main():
+    response = client.chat.completions.create(
+        model="gpt-4o-2024-08-06",
+        messages=[
+            {"role": "system", "content": "You are a helpful math tutor. Guide the user through the solution step by step."},
+            {"role": "user", "content": "how can I solve 8x + 7 = -23"},
+        ],
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "math_response",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "steps": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {"explanation": {"type": "string"}, "output": {"type": "string"}},
+                                "required": ["explanation", "output"],
+                                "additionalProperties": False,
+                            },
+                        },
+                        "final_answer": {"type": "string"},
+                    },
+                    "required": ["steps", "final_answer"],
+                    "additionalProperties": False,
+                },
+                "strict": True,
+            },
+        },
+    )
+
+    print(response.choices[0].message.content)
+
+
+def main2():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_delivery_date",
+                "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "order_id": {
+                            "type": "string",
+                            "description": "The customer's order ID.",
+                        },
+                    },
+                    "required": ["order_id"],
+                    "additionalProperties": False,
+                },
+            },
+            "strict": True,
+        }
+    ]

+    messages = [
+        {"role": "system", "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user."},
+        {"role": "user", "content": "Hi, can you tell me the delivery date for my order with id 5?"},
+    ]

-TraceLog(configuration=LLMInputs(model='gpt-4o-2024-08-06', provider='openai', model_params=ModelParams(temp=1.0, top_p=1.0, frequency_penalty=0.0, presence_penalty=0.0, max_length=None, response_format="<class '__main__.CalendarEvent'>", safe_prompt=None), messages=[{'role': 'system', 'content': 'Extract the event information.'}, {'role': 'user', 'content': 'Alice and Bob are going to a science fair on Friday.'}], history=None, functions=[], function_call=None), inputs=None, output='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}', target=None, latency=1.771622, time_to_first_token=None, input_tokens=32, output_tokens=17, total_tokens=49, cost=0.000415, scores=[], trace_id='e894b955-c844-49f5-8480-71b0841f10b5', parent_trace_id='e894b955-c844-49f5-8480-71b0841f10b5', root_trace_id='e894b955-c844-49f5-8480-71b0841f10b5', start_timestamp='2024-08-06T20:26:42.365024+00:00', organization_id=None, project_uuid=None, error=None, status='success', deployment_id=None, cache_hit=False, output_for_eval_metrics=None, evaluation_metric_names=[], apply_eval_frac=1.0, feedback_score=None, trace_name='llm-openai', children=['600b25b7-417a-4409-8f96-afe8dd8fe8cf'], end_timestamp='2024-08-06T20:26:44.136646+00:00', end_user_identifier=None, session_id=None, metadata=None, tags=None, experiment_uuid=None, images=[], comments=None, annotations=None, depth=0, execution_order=0)
-D {'configuration': {'model': 'gpt-4o-2024-08-06', 'provider': 'openai', 'model_params': {'temp': 1.0, 'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'max_length': None, 'response_format': "<class '__main__.CalendarEvent'>", 'safe_prompt': None}, 'messages': [{'role': 'system', 'content': 'Extract the event information.'}, {'role': 'user', 'content': 'Alice and Bob are going to a science fair on Friday.'}], 'history': None, 'functions': [], 'function_call': None}, 'inputs': None, 'output': '{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}', 'target': None, 'latency': 1.771622, 'time_to_first_token': None, 'input_tokens': 32, 'output_tokens': 17, 'total_tokens': 49, 'cost': 0.000415, 'scores': [], 'trace_id': 'e894b955-c844-49f5-8480-71b0841f10b5', 'parent_trace_id': 'e894b955-c844-49f5-8480-71b0841f10b5', 'root_trace_id': 'e894b955-c844-49f5-8480-71b0841f10b5', 'start_timestamp': '2024-08-06T20:26:42.365024+00:00', 'organization_id': None, 'project_uuid': '1c4dfe49-bf84-11ee-92b3-3a9b36099f82', 'error': None, 'status': 'success', 'deployment_id': None, 'cache_hit': False, 'output_for_eval_metrics': None, 'evaluation_metric_names': [], 'apply_eval_frac': 1.0, 'feedback_score': None, 'trace_name': 'llm-openai', 'children': ['600b25b7-417a-4409-8f96-afe8dd8fe8cf'], 'end_timestamp': '2024-08-06T20:26:44.136646+00:00', 'end_user_identifier': None, 'session_id': None, 'metadata': None, 'tags': None, 'experiment_uuid': None, 'images': [], 'comments': None, 'annotations': None, 'depth': 0, 'execution_order': 0}
+    response = client.chat.completions.create(
+        model="gpt-4o-2024-08-06",
+        messages=messages,
+        tools=tools,
+    )
+    print(response.choices[0].message.tool_calls)
+
+
+if __name__ == "__main__":
+    # event = completion.choices[0].message.parsed
+    # print(event)
+    # main()
+    main2()
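
For reference, a minimal sketch (not part of this commit) of how the strict json_schema output requested by main() above might be consumed. It assumes `response` is the object returned by client.chat.completions.create with the "math_response" schema and that the model answered rather than refusing:

import json

def parse_math_response(response) -> dict:
    # With "strict": True the content is expected to match the declared schema,
    # unless the model refused or the output was filtered.
    message = response.choices[0].message
    if getattr(message, "refusal", None):
        raise ValueError(f"Model refused: {message.refusal}")
    data = json.loads(message.content)
    for step in data["steps"]:
        print(f'{step["explanation"]} -> {step["output"]}')
    print("final answer:", data["final_answer"])
    return data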

parea/parea_logger.py

Lines changed: 0 additions & 3 deletions
@@ -57,11 +57,8 @@ def update_log(self, data: UpdateLog) -> None:
         )

     def record_log(self, data: TraceLog) -> None:
-        print(data)
         data = serialize_metadata_values(data)
         data.project_uuid = self._get_project_uuid()
-        d = asdict(data)
-        print("D", d)
         self._client.request(
             "POST",
             LOG_ENDPOINT,

parea/utils/trace_utils.py

Lines changed: 5 additions & 4 deletions
@@ -1,5 +1,3 @@
-from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Generator, Iterator, List, Optional, Tuple
-
 import contextvars
 import inspect
 import json
@@ -11,6 +9,7 @@
 from datetime import datetime
 from functools import wraps
 from random import random
+from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Generator, Iterator, List, Optional, Tuple

 from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID, TURN_OFF_PAREA_EVAL_LOGGING
 from parea.helpers import gen_trace_id, is_logging_disabled, timezone_aware_now
@@ -49,7 +48,9 @@ def log_in_thread(target_func: Callable, data: Dict[str, Any]):
     logging_thread.start()


-def merge(old, new):
+def merge(old, new, key=None):
+    if key == "error" and old:
+        return json_dumps([old, new])
     if isinstance(old, dict) and isinstance(new, dict):
         return dict(ChainMap(new, old))
     if isinstance(old, list) and isinstance(new, list):
@@ -112,7 +113,7 @@ def trace_insert(data: Dict[str, Any], trace_id: Optional[str] = None):
             return
         for key, new_value in data.items():
             existing_value = current_trace_data.__getattribute__(key)
-            current_trace_data.__setattr__(key, merge(existing_value, new_value) if existing_value else new_value)
+            current_trace_data.__setattr__(key, merge(existing_value, new_value, key) if existing_value else new_value)
     except Exception as e:
         logger.debug(f"Error occurred inserting data into trace log, {e}", exc_info=e)

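
The new `key` parameter means a second error recorded on the same trace is appended rather than overwriting the first. A small illustration (not part of the commit; json.dumps stands in for the json_dumps helper used in trace_utils, and the dict/list branches of the real merge are omitted):

import json

def merge(old, new, key=None):
    # Error values accumulate as a JSON list instead of being overwritten.
    if key == "error" and old:
        return json.dumps([old, new])
    return new

existing = "Error: rate limit hit"
print(merge(existing, "Error: content filtered", key="error"))
# -> ["Error: rate limit hit", "Error: content filtered"]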

parea/utils/universal_encoder.py

Lines changed: 14 additions & 0 deletions
@@ -80,6 +80,18 @@ def handle_dspy_response(self, obj) -> Any:
         else:
             return None

+    def handle_openai_not_given(self, obj) -> Any:
+        try:
+            from openai import NotGiven
+        except ImportError:
+            return None
+
+        from openai import NotGiven
+
+        if isinstance(obj, NotGiven):
+            return {"not_given": None}
+        return None
+
     def default(self, obj: Any):
         if isinstance(obj, str):
             return obj
@@ -115,6 +127,8 @@ def default(self, obj: Any):
             return obj.to_dict(orient="records")
         elif dspy_response := self.handle_dspy_response(obj):
             return dspy_response
+        elif is_openai_not_given := self.handle_openai_not_given(obj):
+            return is_openai_not_given["not_given"]
         elif callable(obj):
             try:
                 return f"<callable {obj.__name__}>"
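
A sketch (not part of the commit; assumes openai>=1.x is installed) of why handle_openai_not_given returns a dict rather than a bare None: the walrus check in default() treats None as "not handled", so the sentinel is wrapped in a truthy dict and unwrapped one line later, which makes NOT_GIVEN serialize as null:

import json
from openai import NOT_GIVEN, NotGiven

class SketchEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, NotGiven):
            handled = {"not_given": None}  # truthy wrapper, mirroring the commit
            return handled["not_given"]    # NOT_GIVEN ends up encoded as null
        return super().default(obj)

print(json.dumps({"max_tokens": NOT_GIVEN, "temperature": 0.2}, cls=SketchEncoder))
# -> {"max_tokens": null, "temperature": 0.2}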

parea/wrapper/utils.py

Lines changed: 23 additions & 11 deletions
@@ -5,7 +5,7 @@
 from typing import Callable, Dict, List, Optional, Union

 import tiktoken
-from openai import __version__ as openai_version
+from openai import __version__ as openai_version, NotGiven
 from parea.constants import ALL_NON_AZURE_MODELS_INFO, AZURE_MODEL_INFO, TURN_OFF_PAREA_EVAL_LOGGING
 from parea.parea_logger import parea_logger
 from parea.schemas.log import LLMInputs, Message, ModelParams, Role
@@ -219,9 +219,12 @@ def clean_json_string(s):

 def _resolve_functions(kwargs):
     if "functions" in kwargs:
-        return kwargs.get("functions", [])
+        f = kwargs.get("functions", [])
+        return None if isinstance(f, NotGiven) else f
     elif "tools" in kwargs:
         tools = kwargs["tools"]
+        if isinstance(tools, NotGiven):
+            return None
         if isinstance(tools, list):
             return [d.get("function", {}) for d in tools]

@@ -233,21 +236,26 @@ def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
 def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
     functions = _resolve_functions(kwargs)
     function_call_default = "auto" if functions else None
+    function_call = kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default)
     response_format = kwargs.get("response_format", None)
-    response_format = str(response_format) if isinstance(response_format, ModelMetaclass) else response_format
-
+    response_format = {"type": "json_schema", "json_schema": str(response_format)} if isinstance(response_format, ModelMetaclass) else response_format
+    temp = kwargs.get("temperature", 1.0)
+    max_length = kwargs.get("max_tokens", None)
+    top_p = kwargs.get("top_p", 1.0)
+    frequency_penalty = kwargs.get("frequency_penalty", 0.0)
+    presence_penalty = kwargs.get("presence_penalty", 0.0)
     return LLMInputs(
         model=model or kwargs.get("model", None),
         provider="openai",
         messages=_convert_oai_messages(kwargs.get("messages", None)),
         functions=functions,
-        function_call=kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default),
+        function_call=None if isinstance(function_call, NotGiven) else function_call,
         model_params=ModelParams(
-            temp=kwargs.get("temperature", 1.0),
-            max_length=kwargs.get("max_tokens", None),
-            top_p=kwargs.get("top_p", 1.0),
-            frequency_penalty=kwargs.get("frequency_penalty", 0.0),
-            presence_penalty=kwargs.get("presence_penalty", 0.0),
+            temp=None if isinstance(temp, NotGiven) else temp,
+            max_length=None if isinstance(max_length, NotGiven) else max_length,
+            top_p=None if isinstance(top_p, NotGiven) else top_p,
+            frequency_penalty=None if isinstance(frequency_penalty, NotGiven) else frequency_penalty,
+            presence_penalty=None if isinstance(presence_penalty, NotGiven) else presence_penalty,
             response_format=response_format,
         ),
     )
@@ -304,7 +312,11 @@ def _compute_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:

 def _process_response(response, model_inputs, trace_id):
     response_message = response.choices[0].message
-    if response_message.content:
+    if response_message.finish_reason == "content_filter":
+        trace_insert({"error": "Error: The content was filtered due to policy violations."}, trace_id)
+    if hasattr(response_message, "refusal"):
+        completion = response_message.refusal
+    elif response_message.content:
         completion = response_message.content
     else:
         completion = _format_function_call(response_message)
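
The repeated `None if isinstance(x, NotGiven) else x` pattern above could also be read as one small helper; a hypothetical illustration (the name _unset_to_none is not in the commit):

from typing import Any, Optional

from openai import NOT_GIVEN, NotGiven

def _unset_to_none(value: Any) -> Optional[Any]:
    # Map the OpenAI NOT_GIVEN sentinel to None; pass real values through unchanged.
    return None if isinstance(value, NotGiven) else value

assert _unset_to_none(NOT_GIVEN) is None
assert _unset_to_none(0.7) == 0.7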

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.192"
+version = "0.2.193"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]
