Skip to content

Commit cfcda67

Browse files
authored
Merge pull request #1037 from parea-ai/PAI-1442-openai-structured-outputs
openai-structured-outputs
2 parents 8f4604a + 33a5374 commit cfcda67

File tree

8 files changed

+250
-19
lines changed

8 files changed

+250
-19
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import os

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from parea import Parea

# Pull OPENAI_API_KEY / PAREA_API_KEY into the environment from a local .env file.
load_dotenv()

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
# Instrument the OpenAI client so every completion call below is traced by Parea.
p.wrap_openai_client(client)
14+
15+
16+
class CalendarEvent(BaseModel):
    """Schema the model's structured output is parsed into.

    Fields mirror what the extraction prompt asks for: an event name,
    its date as free-form text, and the list of participant names.
    """

    name: str
    date: str
    participants: list[str]
20+
21+
22+
def with_pydantic():
    """Structured outputs via a Pydantic model: the SDK parses the reply into CalendarEvent."""
    result = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": "Extract the event information."},
            {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
        ],
        # Passing the model class makes the SDK derive a strict JSON schema from it.
        response_format=CalendarEvent,
    )
    # `.parsed` is the CalendarEvent instance (None if the model refused to answer).
    parsed_event = result.choices[0].message.parsed
    print(parsed_event)
33+
34+
35+
def with_json_schema():
    """Structured outputs via a hand-written JSON schema (`strict: true`) instead of Pydantic."""
    # Schema: a list of reasoning steps plus a final answer; `strict` forbids extra keys.
    math_schema = {
        "type": "json_schema",
        "json_schema": {
            "name": "math_response",
            "schema": {
                "type": "object",
                "properties": {
                    "steps": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {"explanation": {"type": "string"}, "output": {"type": "string"}},
                            "required": ["explanation", "output"],
                            "additionalProperties": False,
                        },
                    },
                    "final_answer": {"type": "string"},
                },
                "required": ["steps", "final_answer"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    }
    resp = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": "You are a helpful math tutor. Guide the user through the solution step by step."},
            {"role": "user", "content": "how can I solve 8x + 7 = -23"},
        ],
        response_format=math_schema,
    )
    # The raw content string is JSON matching the schema above.
    print(resp.choices[0].message.content)
68+
69+
70+
def with_tools():
    """Function calling with a strict tool schema: the model emits a tool call, not text."""
    # One tool: look up a delivery date by order id. `strict: True` enforces the schema.
    delivery_tool = {
        "type": "function",
        "function": {
            "name": "get_delivery_date",
            "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
            "parameters": {
                "type": "object",
                "properties": {
                    "order_id": {
                        "type": "string",
                        "description": "The customer's order ID.",
                    },
                },
                "required": ["order_id"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    }

    conversation = [
        {"role": "system", "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user."},
        {"role": "user", "content": "Hi, can you tell me the delivery date for my order with id 5?"},
    ]

    resp = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=conversation,
        tools=[delivery_tool],
    )
    # Expect a tool_calls entry invoking get_delivery_date with order_id="5".
    print(resp.choices[0].message.tool_calls)
104+
105+
106+
if __name__ == "__main__":
    # Run each structured-output variant in turn; all calls are traced via Parea.
    with_pydantic()
    with_json_schema()
    with_tools()

parea/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def str2bool(v):
146146
"completion": 15.0,
147147
"token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 128000},
148148
},
149+
"gpt-4o-2024-08-06": {
150+
"prompt": 5.0,
151+
"completion": 15.0,
152+
"token_limit": {"max_completion_tokens": 4096, "max_prompt_tokens": 128000},
153+
},
149154
"gpt-4o-mini": {
150155
"prompt": 0.15,
151156
"completion": 0.6,

parea/utils/trace_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def log_in_thread(target_func: Callable, data: Dict[str, Any]):
4949
logging_thread.start()
5050

5151

52-
def merge(old, new):
52+
def merge(old, new, key=None):
53+
if key == "error" and old:
54+
return json_dumps([old, new])
5355
if isinstance(old, dict) and isinstance(new, dict):
5456
return dict(ChainMap(new, old))
5557
if isinstance(old, list) and isinstance(new, list):
@@ -112,7 +114,7 @@ def trace_insert(data: Dict[str, Any], trace_id: Optional[str] = None):
112114
return
113115
for key, new_value in data.items():
114116
existing_value = current_trace_data.__getattribute__(key)
115-
current_trace_data.__setattr__(key, merge(existing_value, new_value) if existing_value else new_value)
117+
current_trace_data.__setattr__(key, merge(existing_value, new_value, key) if existing_value else new_value)
116118
except Exception as e:
117119
logger.debug(f"Error occurred inserting data into trace log, {e}", exc_info=e)
118120

parea/utils/universal_encoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ def handle_dspy_response(self, obj) -> Any:
8181
else:
8282
return None
8383

84+
def handle_openai_not_given(self, obj) -> Any:
85+
try:
86+
from openai import NotGiven
87+
except ImportError:
88+
return None
89+
90+
if isinstance(obj, NotGiven):
91+
return {"not_given": None}
92+
return None
93+
8494
def default(self, obj: Any):
8595
if isinstance(obj, str):
8696
return obj
@@ -116,6 +126,8 @@ def default(self, obj: Any):
116126
return obj.to_dict(orient="records")
117127
elif dspy_response := self.handle_dspy_response(obj):
118128
return dspy_response
129+
elif is_openai_not_given := self.handle_openai_not_given(obj):
130+
return is_openai_not_given["not_given"]
119131
elif callable(obj):
120132
try:
121133
return f"<callable {obj.__name__}>"

parea/wrapper/openai/openai.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ def get_original_methods(self, module_client=openai):
5858
original_methods = {"chat.completions.create": module_client.chat.completions.create}
5959
except openai.OpenAIError:
6060
original_methods = {}
61+
62+
try:
63+
latest_methods = {"beta.chat.completions.parse": module_client.beta.chat.completions.parse}
64+
original_methods.update(latest_methods)
65+
except Exception:
66+
pass
67+
6168
return list(original_methods.keys())
6269

6370
def init(self, log: Callable, cache: Cache = None, module_client=openai):
@@ -103,7 +110,7 @@ def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any],
103110
trace_data.get()[trace_id].output_tokens = output_tokens
104111
trace_data.get()[trace_id].total_tokens = total_tokens
105112
trace_data.get()[trace_id].cost = _compute_cost(input_tokens, output_tokens, model)
106-
trace_data.get()[trace_id].output = output
113+
trace_data.get()[trace_id].output = json_dumps(output) if not isinstance(output, str) else output
107114
return response
108115

109116
def gen_resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response, final_log):
@@ -269,7 +276,13 @@ def _kwargs_to_llm_configuration(kwargs):
269276

270277
@staticmethod
271278
def _get_output(result: Any, model: Optional[str] = None) -> str:
272-
if not isinstance(result, OpenAIObject) and isinstance(result, dict):
279+
try:
280+
from openai.types.chat import ParsedChatCompletion, ParsedChatCompletionMessage
281+
except ImportError:
282+
ParsedChatCompletion = None
283+
ParsedChatCompletionMessage = None
284+
285+
if not isinstance(result, (OpenAIObject, ParsedChatCompletion)) and isinstance(result, dict):
273286
result = convert_to_openai_object(
274287
{
275288
"choices": [
@@ -282,7 +295,9 @@ def _get_output(result: Any, model: Optional[str] = None) -> str:
282295
}
283296
)
284297
response_message = result.choices[0].message
285-
if not response_message.get("content", None) if is_old_openai else not response_message.content:
298+
if isinstance(response_message, ParsedChatCompletionMessage):
299+
completion = response_message.parsed.model_dump_json() if response_message.parsed else ""
300+
elif not response_message.get("content", None) if is_old_openai else not response_message.content:
286301
completion = OpenAIWrapper._format_function_call(response_message)
287302
else:
288303
completion = response_message.content

parea/wrapper/utils.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from functools import lru_cache, wraps
77

88
import tiktoken
9+
from openai import NotGiven
910
from openai import __version__ as openai_version
11+
from pydantic._internal._model_construction import ModelMetaclass
1012

1113
from parea.constants import ALL_NON_AZURE_MODELS_INFO, AZURE_MODEL_INFO, TURN_OFF_PAREA_EVAL_LOGGING
1214
from parea.parea_logger import parea_logger
@@ -220,9 +222,12 @@ def clean_json_string(s):
220222

221223
def _resolve_functions(kwargs):
222224
if "functions" in kwargs:
223-
return kwargs.get("functions", [])
225+
f = kwargs.get("functions", [])
226+
return None if isinstance(f, NotGiven) else f
224227
elif "tools" in kwargs:
225228
tools = kwargs["tools"]
229+
if isinstance(tools, NotGiven):
230+
return None
226231
if isinstance(tools, list):
227232
return [d.get("function", {}) for d in tools]
228233

@@ -234,19 +239,27 @@ def _resolve_functions(kwargs):
234239
def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
235240
functions = _resolve_functions(kwargs)
236241
function_call_default = "auto" if functions else None
242+
function_call = kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default)
243+
response_format = kwargs.get("response_format", None)
244+
response_format = {"type": "json_schema", "json_schema": str(response_format)} if isinstance(response_format, ModelMetaclass) else response_format
245+
temp = kwargs.get("temperature", 1.0)
246+
max_length = kwargs.get("max_tokens", None)
247+
top_p = kwargs.get("top_p", 1.0)
248+
frequency_penalty = kwargs.get("frequency_penalty", 0.0)
249+
presence_penalty = kwargs.get("presence_penalty", 0.0)
237250
return LLMInputs(
238251
model=model or kwargs.get("model", None),
239252
provider="openai",
240253
messages=_convert_oai_messages(kwargs.get("messages", None)),
241254
functions=functions,
242-
function_call=kwargs.get("function_call", function_call_default) or kwargs.get("tool_choice", function_call_default),
255+
function_call=None if isinstance(function_call, NotGiven) else function_call,
243256
model_params=ModelParams(
244-
temp=kwargs.get("temperature", 1.0),
245-
max_length=kwargs.get("max_tokens", None),
246-
top_p=kwargs.get("top_p", 1.0),
247-
frequency_penalty=kwargs.get("frequency_penalty", 0.0),
248-
presence_penalty=kwargs.get("presence_penalty", 0.0),
249-
response_format=kwargs.get("response_format", None),
257+
temp=None if isinstance(temp, NotGiven) else temp,
258+
max_length=None if isinstance(max_length, NotGiven) else max_length,
259+
top_p=None if isinstance(top_p, NotGiven) else top_p,
260+
frequency_penalty=None if isinstance(frequency_penalty, NotGiven) else frequency_penalty,
261+
presence_penalty=None if isinstance(presence_penalty, NotGiven) else presence_penalty,
262+
response_format=response_format,
250263
),
251264
)
252265

@@ -302,7 +315,11 @@ def _compute_cost(prompt_tokens: int, completion_tokens: int, model: str) -> flo
302315

303316
def _process_response(response, model_inputs, trace_id):
304317
response_message = response.choices[0].message
305-
if response_message.content:
318+
if response_message.finish_reason == "content_filter":
319+
trace_insert({"error": "Error: The content was filtered due to policy violations."}, trace_id)
320+
if hasattr(response_message, "refusal"):
321+
completion = response_message.refusal
322+
elif response_message.content:
306323
completion = response_message.content
307324
else:
308325
completion = _format_function_call(response_message)

0 commit comments

Comments
 (0)