Skip to content

Commit 9e4d99a

Browse files
committed
feat: extract tool output from openai-agents sdk
Signed-off-by: CormickKneey <[email protected]>
1 parent a48cd12 commit 9e4d99a

File tree

1 file changed

+53
-8
lines changed

1 file changed

+53
-8
lines changed

areal/experimental/openai/client.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ async def create(
165165
input_ids=prompt_token_ids,
166166
gconfig=gconfig,
167167
rid=str(uuid.uuid4()),
168-
metadata=metadata if metadata is not NOT_GIVEN else {},
168+
metadata=metadata if not is_omitted(metadata) else {},
169169
tokenizer=self.tokenizer,
170170
)
171171

@@ -276,6 +276,27 @@ async def create(
276276
if input is NOT_GIVEN or input is None:
277277
raise ValueError("input is required for Responses.create")
278278

279+
def _convert_tool_output_format(item: dict) -> dict:
280+
"""Convert custom tool output format to standard chat template format.
281+
282+
Converts from: {'call_id': ..., 'output': ..., 'type': 'function_call_output'}
283+
To: {'role': 'tool', 'content': ..., 'tool_call_id': ...}
284+
"""
285+
if (
286+
isinstance(item, dict)
287+
and "output" in item
288+
and item.get("type") == "function_call_output"
289+
):
290+
converted = {
291+
"role": "tool",
292+
"content": item["output"],
293+
}
294+
# Add tool_call_id if present
295+
if "call_id" in item:
296+
converted["tool_call_id"] = item["call_id"]
297+
return converted
298+
return item
299+
279300
def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
280301
messages_list = []
281302
if "content" in item:
@@ -286,13 +307,17 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
286307
elif isinstance(item["content"], Iterable):
287308
for content in item["content"]:
288309
if isinstance(content, dict):
289-
messages_list.append(deepcopy(content))
310+
# Convert tool output format if needed
311+
converted = _convert_tool_output_format(content)
312+
messages_list.append(deepcopy(converted))
290313
else:
291314
raise ValueError("Unsupported content format")
292315
else:
293316
raise ValueError("Unsupported input item format")
294317
else:
295-
messages_list.append(deepcopy(item))
318+
# Convert tool output format if needed
319+
converted = _convert_tool_output_format(item)
320+
messages_list.append(deepcopy(converted))
296321
return messages_list
297322

298323
if isinstance(input, str):
@@ -335,7 +360,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
335360
temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
336361
top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
337362
max_new_tokens = 512
338-
if max_output_tokens is not NOT_GIVEN and max_output_tokens is not None:
363+
if not is_omitted(max_output_tokens):
339364
max_new_tokens = max_output_tokens
340365

341366
stop = kwargs.get("stop", None)
@@ -359,7 +384,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
359384
input_ids=prompt_token_ids,
360385
gconfig=gconfig,
361386
rid=str(uuid.uuid4()),
362-
metadata=metadata if metadata is not NOT_GIVEN else {},
387+
metadata=metadata if not is_omitted(metadata) else {},
363388
tokenizer=self.tokenizer,
364389
)
365390

@@ -420,14 +445,14 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
420445
created_at=current_time,
421446
error=None,
422447
incomplete_details=None,
423-
instructions=None if instructions is NOT_GIVEN else instructions,
424-
metadata=None if metadata is NOT_GIVEN else metadata,
448+
instructions=None if is_omitted(instructions) else instructions,
449+
metadata=None if is_omitted(metadata) else metadata,
425450
model="None",
426451
object="response",
427452
output=resp_output,
428453
parallel_tool_calls=False,
429454
temperature=temp,
430-
tool_choice=tool_choice if tool_choice is not NOT_GIVEN else "none",
455+
tool_choice=tool_choice if not is_omitted(tool_choice) else "none",
431456
tools=tools,
432457
top_p=top_p_val,
433458
background=None,
@@ -751,3 +776,23 @@ def export_responses(
751776
"export_responses is deprecated. Please use export_interactions instead."
752777
)
753778
return self.export_interactions(style)
779+
780+
781+
def is_omitted(value) -> bool:
    """Return True when *value* was effectively not supplied by the caller.

    A value counts as omitted when it is ``None``, the module-level
    ``NOT_GIVEN`` sentinel, or an instance of the OpenAI SDK's
    ``NotGiven`` / ``Omit`` sentinel classes.

    Args:
        value: Any value received from an OpenAI-style API surface.

    Returns:
        bool: True if the value should be treated as absent.
    """
    if value is NOT_GIVEN or value is None:
        return True
    # BUG FIX: the previous implementation imported the NOT_GIVEN *sentinel
    # instance* and passed it as the second argument of isinstance(), which
    # raises "TypeError: isinstance() arg 2 must be a type" whenever the
    # import succeeds — and that TypeError escaped the `except ImportError`
    # handler.  Import the NotGiven *class* instead.
    try:
        from openai import NotGiven as _NotGiven

        if isinstance(value, _NotGiven):
            return True
    except ImportError:
        pass
    # Fallback: match other omit sentinels (e.g. `Omit`) by class name, for
    # SDK versions that do not export the class at top level.  The previous
    # `hasattr(value, "__class__")` guard was dead code — every Python
    # object has `__class__`.
    return type(value).__name__ in ("NotGiven", "Omit")

0 commit comments

Comments
 (0)