Commit 16a4def

CormickKneey authored and Bruce-rl-hw committed
feat: extract tool output from openai-agents sdk (inclusionAI#507)
Signed-off-by: CormickKneey <[email protected]>
1 parent f1a8ab4 commit 16a4def

File tree

1 file changed: +80 -23 lines

areal/experimental/openai/client.py

Lines changed: 80 additions & 23 deletions
@@ -15,6 +15,7 @@
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessage,
+    ChatCompletionToolMessageParam,
     ChatCompletionToolParam,
 )
 from openai.types.chat.chat_completion import Choice
@@ -106,7 +107,7 @@ async def create(
         if extra_body is None:
             extra_body = {}
         # Convert messages to prompt format
-        tools = tools if tools is not NOT_GIVEN else None
+        tools = tools if not is_omitted(tools) else None
         if self.chat_template_type == "hf":
             prompt_token_ids = self.tokenizer.apply_chat_template(
                 messages_list,
@@ -128,23 +129,23 @@ async def create(
                 f"Unsupported chat_template_type {self.chat_template_type}"
             )
 
-        temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
+        temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
         max_new_tokens = 512
-        if max_tokens is not NOT_GIVEN and max_tokens is not None:
+        if not is_omitted(max_tokens):
             max_new_tokens = max_tokens - len(prompt_token_ids)
             if max_new_tokens <= 0:
                 raise RuntimeError(
                     "max_tokens must be greater than the number of prompt tokens"
                 )
-        if max_completion_tokens is not NOT_GIVEN and max_completion_tokens is not None:
+        if not is_omitted(max_completion_tokens):
             max_new_tokens = min(max_new_tokens, max_completion_tokens)
 
-        top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
-        stop_tokens = None if stop is NOT_GIVEN else stop
+        top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
+        stop_tokens = None if is_omitted(stop) else stop
         if stop_tokens is not None and not isinstance(stop_tokens, list):
             stop_tokens = [stop_tokens]
 
-        if frequency_penalty is NOT_GIVEN or frequency_penalty is None:
+        if is_omitted(frequency_penalty):
            frequency_penalty = 0.0
 
        # Create generation config
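
To make the token-budget arithmetic above concrete, here is a minimal standalone sketch of the same mapping. It mirrors the diff's logic, with plain None standing in for omitted/NOT_GIVEN values; the 512-token default is taken from the diff, while the helper name is hypothetical:

def resolve_max_new_tokens(
    prompt_len: int,
    max_tokens: int | None = None,
    max_completion_tokens: int | None = None,
) -> int:
    # Default completion budget when neither cap is supplied.
    max_new_tokens = 512
    if max_tokens is not None:
        # max_tokens caps prompt + completion, so subtract the prompt length.
        max_new_tokens = max_tokens - prompt_len
        if max_new_tokens <= 0:
            raise RuntimeError(
                "max_tokens must be greater than the number of prompt tokens"
            )
    if max_completion_tokens is not None:
        # max_completion_tokens caps only the generated tokens.
        max_new_tokens = min(max_new_tokens, max_completion_tokens)
    return max_new_tokens

For example, a 100-token prompt with max_tokens=612 and max_completion_tokens=256 yields min(612 - 100, 256) == 256 new tokens.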
@@ -165,7 +166,7 @@ async def create(
             input_ids=prompt_token_ids,
             gconfig=gconfig,
             rid=str(uuid.uuid4()),
-            metadata=metadata if metadata is not NOT_GIVEN else {},
+            metadata=metadata if not is_omitted(metadata) else {},
             tokenizer=self.tokenizer,
         )
 
@@ -215,7 +216,7 @@ async def create(
             ),
         )
 
-        if store is NOT_GIVEN or store:
+        if is_omitted(store) or store:
             # Cache the completion with its input messages
             self._cache[completion_id] = InteractionWithTokenLogpReward(
                 completion=deepcopy(chat_completion),
@@ -269,13 +270,45 @@ async def create(
 
         # Build a simple messages list compatible with tokenizer chat template
         messages_list: list[dict] = []
-        if instructions is not NOT_GIVEN and instructions is not None:
+        if not is_omitted(instructions):
             messages_list = [
                 {"role": "system", "content": instructions},
             ]
-        if input is NOT_GIVEN or input is None:
+        if is_omitted(input):
             raise ValueError("input is required for Responses.create")
 
+        def _convert_tool_output_format(
+            item: dict,
+        ) -> ChatCompletionToolMessageParam | dict:
+            """Convert custom tool output format to standard chat template format.
+
+            Converts openai.types.responses.response_input_item_param.FunctionCallOutput
+            to openai.types.chat.ChatCompletionToolMessageParam.
+
+            Args:
+                item: Input dict, could be FunctionCallOutput from openai-agents SDK
+                    with format: {'call_id': str, 'output': str, 'type': 'function_call_output'}
+
+            Returns:
+                ChatCompletionToolMessageParam (TypedDict) with format:
+                    {'role': 'tool', 'content': str, 'tool_call_id': str}
+                or the original dict if conversion is not needed.
+            """
+            if (
+                isinstance(item, dict)
+                and "output" in item
+                and item.get("type") == "function_call_output"
+            ):
+                converted = {
+                    "role": "tool",
+                    "content": item["output"],
+                }
+                # Add tool_call_id if present
+                if "call_id" in item:
+                    converted["tool_call_id"] = item["call_id"]
+                return converted
+            return item
+
         def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             messages_list = []
             if "content" in item:
@@ -286,13 +319,17 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
                 elif isinstance(item["content"], Iterable):
                     for content in item["content"]:
                         if isinstance(content, dict):
-                            messages_list.append(deepcopy(content))
+                            # Convert tool output format if needed
+                            converted = _convert_tool_output_format(content)
+                            messages_list.append(deepcopy(converted))
                         else:
                             raise ValueError("Unsupported content format")
                 else:
                     raise ValueError("Unsupported input item format")
             else:
-                messages_list.append(deepcopy(item))
+                # Convert tool output format if needed
+                converted = _convert_tool_output_format(item)
+                messages_list.append(deepcopy(converted))
             return messages_list
 
         if isinstance(input, str):
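
Taken together, the two converted branches mean a Responses-style input list can now mix chat messages with openai-agents tool outputs. A hypothetical example of what _build_messages_list would produce after this change (the input values are illustrative):

input_items = [
    {"role": "user", "content": "What is 6 * 7?"},
    {"call_id": "call_1", "output": "42", "type": "function_call_output"},
]
# The first item carries a "content" key and passes through as before. The
# second item has no "content" key, so it hits the top-level else branch and
# is converted, giving:
# [{"role": "user", "content": "What is 6 * 7?"},
#  {"role": "tool", "content": "42", "tool_call_id": "call_1"}]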
@@ -309,7 +346,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             )
 
         # Apply chat template
-        tools = list(tools) if tools is not NOT_GIVEN else None
+        tools = list(tools) if not is_omitted(tools) else None
         if self.chat_template_type == "hf":
             prompt_token_ids = self.tokenizer.apply_chat_template(
                 messages_list,
@@ -332,10 +369,10 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             )
 
         # Map sampling params
-        temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
-        top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
+        temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
+        top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
         max_new_tokens = 512
-        if max_output_tokens is not NOT_GIVEN and max_output_tokens is not None:
+        if not is_omitted(max_output_tokens):
             max_new_tokens = max_output_tokens
 
         stop = kwargs.get("stop", None)
@@ -359,7 +396,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             input_ids=prompt_token_ids,
             gconfig=gconfig,
             rid=str(uuid.uuid4()),
-            metadata=metadata if metadata is not NOT_GIVEN else {},
+            metadata=metadata if not is_omitted(metadata) else {},
             tokenizer=self.tokenizer,
         )
 
@@ -369,7 +406,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
 
         # Parse tool calls.
         tool_calls = None
-        if tool_choice != "none" and tools:
+        if not is_omitted(tool_choice) and tool_choice != "none" and tools:
             tool_calls, output_text, engine_resp.stop_reason = process_tool_calls(
                 output_text,
                 tools,
@@ -420,14 +457,14 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             created_at=current_time,
             error=None,
             incomplete_details=None,
-            instructions=None if instructions is NOT_GIVEN else instructions,
-            metadata=None if metadata is NOT_GIVEN else metadata,
+            instructions=None if is_omitted(instructions) else instructions,
+            metadata=None if is_omitted(metadata) else metadata,
             model="None",
             object="response",
             output=resp_output,
             parallel_tool_calls=False,
             temperature=temp,
-            tool_choice=tool_choice if tool_choice is not NOT_GIVEN else "none",
+            tool_choice=tool_choice if not is_omitted(tool_choice) else "none",
             tools=tools,
             top_p=top_p_val,
             background=None,
@@ -453,7 +490,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
             response=deepcopy(response),
             model_response=engine_resp,  # Should not deepcopy because of tokenizer
             input_data=(
-                deepcopy(input) if input is not NOT_GIVEN else ""
+                deepcopy(input) if not is_omitted(input) else ""
             ),  # Store a copy of the input data
             chat_template_type=self.chat_template_type,
         )
@@ -751,3 +788,23 @@ def export_responses(
             "export_responses is deprecated. Please use export_interactions instead."
         )
         return self.export_interactions(style)
+
+
+def is_omitted(value) -> bool:
+    """Check if a value is NOT_GIVEN or Omit type or None."""
+    if value is NOT_GIVEN or value is None:
+        return True
+    # Use isinstance for type safety and robustness
+    # Check for common omitted types from OpenAI SDK
+    try:
+        from openai import Omit
+
+        if isinstance(value, Omit):
+            return True
+    except ImportError:
+        pass
+
+    # Fallback for other omit types
+    if hasattr(value, "__class__"):
+        return value.__class__.__name__ in ("NotGiven", "Omit")
+    return False
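
A quick behavioral sketch of the new helper. NOT_GIVEN is a real openai export; the try/except inside the helper suggests Omit may not be importable in every SDK version, so that check is guarded here too:

from openai import NOT_GIVEN

assert is_omitted(NOT_GIVEN) is True   # SDK "not given" sentinel
assert is_omitted(None) is True        # explicit None also counts as omitted
assert is_omitted(0.7) is False        # concrete values pass through
assert is_omitted("") is False         # falsy-but-present values are kept

try:
    from openai import Omit
    assert is_omitted(Omit()) is True  # SDK Omit sentinel, when available
except ImportError:
    pass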
