
Commit 1635700

feat: extract tool output from openai-agents sdk
Signed-off-by: CormickKneey <[email protected]>
1 parent: a48cd12

1 file changed (+68, -23)

areal/experimental/openai/client.py

Lines changed: 68 additions & 23 deletions
@@ -106,7 +106,7 @@ async def create(
         if extra_body is None:
             extra_body = {}
         # Convert messages to prompt format
-        tools = tools if tools is not NOT_GIVEN else None
+        tools = tools if not is_omitted(tools) else None
         if self.chat_template_type == "hf":
             prompt_token_ids = self.tokenizer.apply_chat_template(
                 messages_list,
@@ -128,23 +128,23 @@ async def create(
                 f"Unsupported chat_template_type {self.chat_template_type}"
             )

-        temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
+        temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
         max_new_tokens = 512
-        if max_tokens is not NOT_GIVEN and max_tokens is not None:
+        if not is_omitted(max_tokens):
             max_new_tokens = max_tokens - len(prompt_token_ids)
             if max_new_tokens <= 0:
                 raise RuntimeError(
                     "max_tokens must be greater than the number of prompt tokens"
                 )
-        if max_completion_tokens is not NOT_GIVEN and max_completion_tokens is not None:
+        if not is_omitted(max_completion_tokens):
             max_new_tokens = min(max_new_tokens, max_completion_tokens)

-        top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
-        stop_tokens = None if stop is NOT_GIVEN else stop
+        top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
+        stop_tokens = None if is_omitted(stop) else stop
         if stop_tokens is not None and not isinstance(stop_tokens, list):
             stop_tokens = [stop_tokens]

-        if frequency_penalty is NOT_GIVEN or frequency_penalty is None:
+        if is_omitted(frequency_penalty):
             frequency_penalty = 0.0

         # Create generation config
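
For reference, the default resolution these guards implement, pulled out as a minimal standalone sketch; it assumes the is_omitted helper added at the bottom of this diff is in scope, and the values are illustrative:

    from openai import NOT_GIVEN

    def resolve_sampling(temperature=NOT_GIVEN, top_p=NOT_GIVEN):
        # Mirrors the mapping above: omitted -> defaults,
        # explicit values (including 0) are respected.
        temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
        top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
        return temp, top_p_val

    assert resolve_sampling() == (1.0, 1.0)               # omitted -> defaults
    assert resolve_sampling(temperature=0) == (0.0, 1.0)  # explicit 0 kept
    assert resolve_sampling(temperature=0.7)[0] == 0.7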
@@ -165,7 +165,7 @@ async def create(
             input_ids=prompt_token_ids,
             gconfig=gconfig,
             rid=str(uuid.uuid4()),
-            metadata=metadata if metadata is not NOT_GIVEN else {},
+            metadata=metadata if not is_omitted(metadata) else {},
             tokenizer=self.tokenizer,
         )
@@ -215,7 +215,7 @@ async def create(
             ),
         )

-        if store is NOT_GIVEN or store:
+        if is_omitted(store) or store:
             # Cache the completion with its input messages
             self._cache[completion_id] = InteractionWithTokenLogpReward(
                 completion=deepcopy(chat_completion),
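
With the same helper, caching stays opt-out: an omitted store (NOT_GIVEN, Omit, or None) caches by default, and only an explicit store=False skips the cache. A quick sketch of the guard in isolation, again assuming is_omitted as defined at the end of this diff:

    from openai import NOT_GIVEN

    def should_cache(store):
        return bool(is_omitted(store) or store)

    assert should_cache(NOT_GIVEN)    # omitted -> cache by default
    assert should_cache(True)
    assert not should_cache(False)    # explicit opt-out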
@@ -269,13 +269,34 @@ async def create(

         # Build a simple messages list compatible with tokenizer chat template
         messages_list: list[dict] = []
-        if instructions is not NOT_GIVEN and instructions is not None:
+        if not is_omitted(instructions):
             messages_list = [
                 {"role": "system", "content": instructions},
             ]
-        if input is NOT_GIVEN or input is None:
+        if is_omitted(input):
             raise ValueError("input is required for Responses.create")

+        def _convert_tool_output_format(item: dict) -> dict:
+            """Convert custom tool output format to standard chat template format.
+
+            Converts from: {'call_id': ..., 'output': ..., 'type': 'function_call_output'}
+            To: {'role': 'tool', 'content': ..., 'tool_call_id': ...}
+            """
+            if (
+                isinstance(item, dict)
+                and "output" in item
+                and item.get("type") == "function_call_output"
+            ):
+                converted = {
+                    "role": "tool",
+                    "content": item["output"],
+                }
+                # Add tool_call_id if present
+                if "call_id" in item:
+                    converted["tool_call_id"] = item["call_id"]
+                return converted
+            return item
+
         def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             messages_list = []
             if "content" in item:
@@ -286,13 +307,17 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
                 elif isinstance(item["content"], Iterable):
                     for content in item["content"]:
                         if isinstance(content, dict):
-                            messages_list.append(deepcopy(content))
+                            # Convert tool output format if needed
+                            converted = _convert_tool_output_format(content)
+                            messages_list.append(deepcopy(converted))
                         else:
                             raise ValueError("Unsupported content format")
                 else:
                     raise ValueError("Unsupported input item format")
             else:
-                messages_list.append(deepcopy(item))
+                # Convert tool output format if needed
+                converted = _convert_tool_output_format(item)
+                messages_list.append(deepcopy(converted))
             return messages_list

         if isinstance(input, str):
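
Taken together, the two call sites mean a Responses-style input list that mixes ordinary messages and tool results now flattens into chat-template-ready messages, with the conversion applied before the deepcopy so only the normalized shape is stored. A sketch of the intended mapping (item shapes follow the docstring above; the id is illustrative):

    input_items = [
        {"role": "user", "content": "What's 2+2?"},
        {"type": "function_call_output", "call_id": "call_1", "output": "4"},
    ]
    # Expected flattened messages_list:
    # [
    #     {"role": "user", "content": "What's 2+2?"},
    #     {"role": "tool", "content": "4", "tool_call_id": "call_1"},
    # ]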
@@ -309,7 +334,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             )

         # Apply chat template
-        tools = list(tools) if tools is not NOT_GIVEN else None
+        tools = list(tools) if not is_omitted(tools) else None
         if self.chat_template_type == "hf":
             prompt_token_ids = self.tokenizer.apply_chat_template(
                 messages_list,
@@ -332,10 +357,10 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
         )

         # Map sampling params
-        temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
-        top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
+        temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
+        top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
         max_new_tokens = 512
-        if max_output_tokens is not NOT_GIVEN and max_output_tokens is not None:
+        if not is_omitted(max_output_tokens):
             max_new_tokens = max_output_tokens

         stop = kwargs.get("stop", None)
@@ -359,7 +384,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             input_ids=prompt_token_ids,
             gconfig=gconfig,
             rid=str(uuid.uuid4()),
-            metadata=metadata if metadata is not NOT_GIVEN else {},
+            metadata=metadata if not is_omitted(metadata) else {},
             tokenizer=self.tokenizer,
         )
@@ -369,7 +394,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:

         # Parse tool calls.
         tool_calls = None
-        if tool_choice != "none" and tools:
+        if not is_omitted(tool_choice) and tool_choice != "none" and tools:
             tool_calls, output_text, engine_resp.stop_reason = process_tool_calls(
                 output_text,
                 tools,
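
One behavioral note on this hunk: the omission sentinel compares unequal to the string "none", so under the old guard an omitted tool_choice still enabled tool-call parsing whenever tools were present. With the added check, parsing requires tool_choice to be explicitly set. A short demonstration of the sentinel comparison:

    from openai import NOT_GIVEN

    tool_choice = NOT_GIVEN
    assert tool_choice != "none"      # old guard: True, so parsing ran anyway
    assert is_omitted(tool_choice)    # new guard: short-circuits, parsing skipped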
@@ -420,14 +445,14 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             created_at=current_time,
             error=None,
             incomplete_details=None,
-            instructions=None if instructions is NOT_GIVEN else instructions,
-            metadata=None if metadata is NOT_GIVEN else metadata,
+            instructions=None if is_omitted(instructions) else instructions,
+            metadata=None if is_omitted(metadata) else metadata,
             model="None",
             object="response",
             output=resp_output,
             parallel_tool_calls=False,
             temperature=temp,
-            tool_choice=tool_choice if tool_choice is not NOT_GIVEN else "none",
+            tool_choice=tool_choice if not is_omitted(tool_choice) else "none",
             tools=tools,
             top_p=top_p_val,
             background=None,
@@ -453,7 +478,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             response=deepcopy(response),
             model_response=engine_resp,  # Should not deepcopy because of tokenizer
             input_data=(
-                deepcopy(input) if input is not NOT_GIVEN else ""
+                deepcopy(input) if not is_omitted(input) else ""
             ),  # Store a copy of the input data
             chat_template_type=self.chat_template_type,
         )
@@ -751,3 +776,23 @@ def export_responses(
             "export_responses is deprecated. Please use export_interactions instead."
         )
         return self.export_interactions(style)
+
+
+def is_omitted(value) -> bool:
+    """Check if a value is NOT_GIVEN or Omit type or None."""
+    if value is NOT_GIVEN or value is None:
+        return True
+    # Use isinstance for type safety and robustness
+    # Check for common omitted types from OpenAI SDK
+    try:
+        from openai import Omit
+
+        if isinstance(value, Omit):
+            return True
+    except ImportError:
+        pass
+
+    # Fallback for other omit types
+    if hasattr(value, "__class__"):
+        return value.__class__.__name__ in ("NotGiven", "Omit")
+    return False
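
For completeness, the helper's expected behavior across the sentinel types it recognizes, as a quick usage sketch (assumes the openai package is installed, which the import inside the function already relies on):

    from openai import NOT_GIVEN, Omit

    assert is_omitted(NOT_GIVEN)    # SDK "not given" sentinel
    assert is_omitted(None)         # None is folded into "omitted"
    assert is_omitted(Omit())       # Omit instances as well
    assert not is_omitted(0)        # real values pass through,
    assert not is_omitted("")       # even falsy ones
    assert not is_omitted("auto")

Treating None as omitted is what lets the call sites above drop their separate `is not None` checks.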
