areal/experimental/openai/client.py (103 changes: 80 additions & 23 deletions)
@@ -15,6 +15,7 @@
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
+ChatCompletionToolMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion import Choice
@@ -106,7 +107,7 @@ async def create(
if extra_body is None:
extra_body = {}
# Convert messages to prompt format
-tools = tools if tools is not NOT_GIVEN else None
+tools = tools if not is_omitted(tools) else None
if self.chat_template_type == "hf":
prompt_token_ids = self.tokenizer.apply_chat_template(
messages_list,
@@ -128,23 +129,23 @@
f"Unsupported chat_template_type {self.chat_template_type}"
)

-temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
+temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
max_new_tokens = 512
-if max_tokens is not NOT_GIVEN and max_tokens is not None:
+if not is_omitted(max_tokens):
max_new_tokens = max_tokens - len(prompt_token_ids)
if max_new_tokens <= 0:
raise RuntimeError(
"max_tokens must be greater than the number of prompt tokens"
)
-if max_completion_tokens is not NOT_GIVEN and max_completion_tokens is not None:
+if not is_omitted(max_completion_tokens):
max_new_tokens = min(max_new_tokens, max_completion_tokens)

-top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
-stop_tokens = None if stop is NOT_GIVEN else stop
+top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
+stop_tokens = None if is_omitted(stop) else stop
if stop_tokens is not None and not isinstance(stop_tokens, list):
stop_tokens = [stop_tokens]

-if frequency_penalty is NOT_GIVEN or frequency_penalty is None:
+if is_omitted(frequency_penalty):
frequency_penalty = 0.0

# Create generation config
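The refactor replaces every `is NOT_GIVEN` identity check with `is_omitted`, which also treats `None` and the SDK's `Omit` sentinel as absent. A minimal sketch of the resulting defaulting behavior; `is_omitted_` below is a simplified stand-in for the helper added at the end of this diff, not the module's actual code:

```python
# Simplified stand-in for the is_omitted helper added at the end of this diff.
from openai import NOT_GIVEN

def is_omitted_(value) -> bool:
    return value is None or type(value).__name__ in ("NotGiven", "Omit")

def map_sampling_params(temperature=NOT_GIVEN, top_p=NOT_GIVEN):
    # Mirrors the defaulting logic above: omitted -> default, explicit 0.0 kept.
    temp = 1.0 if is_omitted_(temperature) else (temperature or 0.0)
    top_p_val = 1.0 if is_omitted_(top_p) else (top_p or 1.0)
    return temp, top_p_val

assert map_sampling_params() == (1.0, 1.0)                 # omitted -> defaults
assert map_sampling_params(temperature=0.0) == (0.0, 1.0)  # explicit 0.0 kept
assert map_sampling_params(temperature=None) == (1.0, 1.0) # None now counts as omitted
```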
@@ -165,7 +166,7 @@
input_ids=prompt_token_ids,
gconfig=gconfig,
rid=str(uuid.uuid4()),
-metadata=metadata if metadata is not NOT_GIVEN else {},
+metadata=metadata if not is_omitted(metadata) else {},
tokenizer=self.tokenizer,
)

@@ -215,7 +216,7 @@ async def create(
),
)

-if store is NOT_GIVEN or store:
+if is_omitted(store) or store:
# Cache the completion with its input messages
self._cache[completion_id] = InteractionWithTokenLogpReward(
completion=deepcopy(chat_completion),
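With `is_omitted`, caching remains opt-out: an omitted `store` (or `None`) caches the interaction, and only an explicit `store=False` skips it. A small sketch of the gate, again with a simplified stand-in for the helper:

```python
from openai import NOT_GIVEN

def is_omitted_(value) -> bool:  # simplified stand-in, see end of diff
    return value is None or type(value).__name__ in ("NotGiven", "Omit")

def should_store(store) -> bool:
    # Mirrors `if is_omitted(store) or store:` above.
    return is_omitted_(store) or bool(store)

assert should_store(NOT_GIVEN)   # default: cache
assert should_store(None)        # None also caches
assert should_store(True)
assert not should_store(False)   # explicit opt-out
```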
@@ -269,13 +270,45 @@ async def create(

# Build a simple messages list compatible with tokenizer chat template
messages_list: list[dict] = []
-if instructions is not NOT_GIVEN and instructions is not None:
+if not is_omitted(instructions):
messages_list = [
{"role": "system", "content": instructions},
]
-if input is NOT_GIVEN or input is None:
+if is_omitted(input):
raise ValueError("input is required for Responses.create")

+def _convert_tool_output_format(
+item: dict,
+) -> ChatCompletionToolMessageParam | dict:
+"""Convert custom tool output format to standard chat template format.
+
+Converts openai.types.responses.response_input_item_param.FunctionCallOutput
+to openai.types.chat.ChatCompletionToolMessageParam.
+
+Args:
+item: Input dict, could be FunctionCallOutput from openai-agents SDK
+with format: {'call_id': str, 'output': str, 'type': 'function_call_output'}
+
+Returns:
+ChatCompletionToolMessageParam (TypedDict) with format:
+{'role': 'tool', 'content': str, 'tool_call_id': str}
+or the original dict if conversion is not needed.
+"""
+if (
+isinstance(item, dict)
+and "output" in item
+and item.get("type") == "function_call_output"
+):
+converted = {
+"role": "tool",
+"content": item["output"],
+}
+# Add tool_call_id if present
+if "call_id" in item:
+converted["tool_call_id"] = item["call_id"]
+return converted
+return item
+
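A usage sketch of the conversion above; the dict values are hypothetical, and the assertions assume `_convert_tool_output_format` exactly as defined in this diff:

```python
# A FunctionCallOutput item as produced by the openai-agents SDK
# (hypothetical call_id and output values):
fn_output = {
    "type": "function_call_output",
    "call_id": "call_abc123",
    "output": '{"temperature_c": 21}',
}
# ...is rewritten into a standard chat tool message:
assert _convert_tool_output_format(fn_output) == {
    "role": "tool",
    "content": '{"temperature_c": 21}',
    "tool_call_id": "call_abc123",
}
# Anything else passes through unchanged:
msg = {"role": "user", "content": "hi"}
assert _convert_tool_output_format(msg) is msg
```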
def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
messages_list = []
if "content" in item:
@@ -286,13 +319,17 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
elif isinstance(item["content"], Iterable):
for content in item["content"]:
if isinstance(content, dict):
-messages_list.append(deepcopy(content))
+# Convert tool output format if needed
+converted = _convert_tool_output_format(content)
+messages_list.append(deepcopy(converted))
else:
raise ValueError("Unsupported content format")
else:
raise ValueError("Unsupported input item format")
else:
-messages_list.append(deepcopy(item))
+# Convert tool output format if needed
+converted = _convert_tool_output_format(item)
+messages_list.append(deepcopy(converted))
return messages_list

if isinstance(input, str):
@@ -309,7 +346,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
)

# Apply chat template
-tools = list(tools) if tools is not NOT_GIVEN else None
+tools = list(tools) if not is_omitted(tools) else None
if self.chat_template_type == "hf":
prompt_token_ids = self.tokenizer.apply_chat_template(
messages_list,
@@ -332,10 +369,10 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
)

# Map sampling params
-temp = 1.0 if temperature is NOT_GIVEN else (temperature or 0.0)
-top_p_val = 1.0 if top_p is NOT_GIVEN else (top_p or 1.0)
+temp = 1.0 if is_omitted(temperature) else (temperature or 0.0)
+top_p_val = 1.0 if is_omitted(top_p) else (top_p or 1.0)
max_new_tokens = 512
-if max_output_tokens is not NOT_GIVEN and max_output_tokens is not None:
+if not is_omitted(max_output_tokens):
max_new_tokens = max_output_tokens

stop = kwargs.get("stop", None)
@@ -359,7 +396,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
input_ids=prompt_token_ids,
gconfig=gconfig,
rid=str(uuid.uuid4()),
-metadata=metadata if metadata is not NOT_GIVEN else {},
+metadata=metadata if not is_omitted(metadata) else {},
tokenizer=self.tokenizer,
)

@@ -369,7 +406,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:

# Parse tool calls.
tool_calls = None
if tool_choice != "none" and tools:
if not is_omitted(tool_choice) and tool_choice != "none" and tools:
tool_calls, output_text, engine_resp.stop_reason = process_tool_calls(
output_text,
tools,
@@ -420,14 +457,14 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
created_at=current_time,
error=None,
incomplete_details=None,
-instructions=None if instructions is NOT_GIVEN else instructions,
-metadata=None if metadata is NOT_GIVEN else metadata,
+instructions=None if is_omitted(instructions) else instructions,
+metadata=None if is_omitted(metadata) else metadata,
model="None",
object="response",
output=resp_output,
parallel_tool_calls=False,
temperature=temp,
-tool_choice=tool_choice if tool_choice is not NOT_GIVEN else "none",
+tool_choice=tool_choice if not is_omitted(tool_choice) else "none",
tools=tools,
top_p=top_p_val,
background=None,
@@ -453,7 +490,7 @@ def _build_messages_list(item: ResponseInputItemParam) -> list[dict]:
response=deepcopy(response),
model_response=engine_resp, # Should not deepcopy because of tokenizer
input_data=(
-deepcopy(input) if input is not NOT_GIVEN else ""
+deepcopy(input) if not is_omitted(input) else ""
), # Store a copy of the input data
chat_template_type=self.chat_template_type,
)
@@ -751,3 +788,23 @@ def export_responses(
"export_responses is deprecated. Please use export_interactions instead."
)
return self.export_interactions(style)


+def is_omitted(value) -> bool:
+"""Check if a value is NOT_GIVEN or Omit type or None."""
+if value is NOT_GIVEN or value is None:
+return True
+# Use isinstance for type safety and robustness
+# Check for common omitted types from OpenAI SDK
+try:
+from openai import Omit
+
+if isinstance(value, Omit):
+return True
+except ImportError:
+pass
+
+# Fallback for other omit types
+if hasattr(value, "__class__"):
+return value.__class__.__name__ in ("NotGiven", "Omit")
+return False
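Assuming the SDK exports `NOT_GIVEN` (and `Omit` where available; the lazy import keeps older SDKs working), the helper behaves as follows:

```python
from openai import NOT_GIVEN

assert is_omitted(NOT_GIVEN)
assert is_omitted(None)
assert not is_omitted(0)       # falsy but explicitly provided
assert not is_omitted("")
assert not is_omitted(["stop"])
```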