20 changes: 18 additions & 2 deletions areal/dataset/__init__.py
@@ -10,7 +10,14 @@
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k", "hh-rlhf", "torl_data"]
VALID_DATASETS = [
"gsm8k",
"clevr_count_70k",
"geometry3k",
"hh-rlhf",
"torl_data",
"terminal_bench",
]

logger = logging.getLogger("Dataset")

@@ -24,7 +31,6 @@ def _get_custom_dataset(
processor: Optional["ProcessorMixin"] = None,
**kwargs,
) -> "Dataset":

if "gsm8k" in path and type == "sft":
from .gsm8k import get_gsm8k_sft_dataset

@@ -105,6 +111,16 @@ def _get_custom_dataset(
max_length=max_length,
**kwargs,
)
elif "terminal_bench" in path and type == "rl":
from .terminal_bench import get_terminal_bench_rl_dataset

return get_terminal_bench_rl_dataset(
path=path,
split=split,
tokenizer=tokenizer,
max_length=max_length,
**kwargs,
)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {type} is not supported. "
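For context, a minimal sketch of how this new branch would be reached through the dataset factory; the parquet path and tokenizer are placeholders, and any keyword names not visible in the diff are assumptions:

# Sketch only: path and tokenizer are placeholders, not part of the diff.
from transformers import AutoTokenizer

from areal.dataset import _get_custom_dataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
dataset = _get_custom_dataset(
    path="/data/terminal_bench/train.parquet",  # "terminal_bench" in the path selects the new branch
    type="rl",
    split="train",
    tokenizer=tokenizer,
    max_length=32768,
)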
60 changes: 60 additions & 0 deletions areal/dataset/terminal_bench.py
@@ -0,0 +1,60 @@
from typing import TYPE_CHECKING

from datasets import load_dataset

if TYPE_CHECKING:
from transformers import PreTrainedTokenizerFast


def get_terminal_bench_rl_dataset(
path: str,
split: str,
tokenizer: "PreTrainedTokenizerFast",
max_length: int | None = None,
):
"""Load terminal-bench dataset for RL training.

The dataset should be in parquet format with the following columns:
- prompt: The formatted prompt for the task
- task_name: Name of the task
- instruction: Raw instruction text
- extra_info: JSON string containing task metadata
"""
# Load from parquet file
dataset = load_dataset("parquet", data_files={split: path}, split=split)

# The dataset already has the right format from the converter:
# - prompt: contains the formatted conversation
# - task_name, instruction, extra_info: metadata fields

# For RL training, we need to extract messages from the prompt or extra_info
def process(sample):
# The prompt is already formatted, but we need to extract the instruction
# to create a messages structure for the workflow
instruction = sample.get("instruction", "")
task_name = sample.get("task_name", "")
dockerfile_contents = sample.get("dockerfile_contents", "")

# Return data in the format expected by the workflow
return {
"instruction": instruction,
"task_name": task_name,
"dockerfile_contents": dockerfile_contents,
"extra_info": sample.get("extra_info", ""),
"data_source": sample.get("data_source", "terminal_bench"),
}

dataset = dataset.map(process)

# Filter out sequences longer than max_length if specified
if max_length is not None:

def filter_length(samples):
# Tokenize instructions in batches for efficiency
instructions = samples["instruction"]
tokens_list = tokenizer(instructions, add_special_tokens=False)["input_ids"]
return [len(tokens) <= max_length for tokens in tokens_list]

dataset = dataset.filter(filter_length, batched=True)

return dataset
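A minimal usage sketch for the loader above; the parquet path and tokenizer are placeholders:

# Sketch only: path and tokenizer are placeholders.
from transformers import AutoTokenizer

from areal.dataset.terminal_bench import get_terminal_bench_rl_dataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
train_ds = get_terminal_bench_rl_dataset(
    path="/data/terminal_bench/train.parquet",
    split="train",
    tokenizer=tokenizer,
    max_length=32768,  # samples whose tokenized instruction exceeds this are dropped
)
print(train_ds[0]["task_name"], train_ds[0]["instruction"][:80])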
18 changes: 3 additions & 15 deletions areal/experimental/openai/client.py
@@ -15,7 +15,6 @@
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion import Choice
@@ -277,22 +276,11 @@ async def create(
if is_omitted(input):
raise ValueError("input is required for Responses.create")

def _convert_tool_output_format(
item: dict,
) -> ChatCompletionToolMessageParam | dict:
def _convert_tool_output_format(item: dict) -> dict:
"""Convert custom tool output format to standard chat template format.

Converts openai.types.responses.response_input_item_param.FunctionCallOutput
to openai.types.chat.ChatCompletionToolMessageParam.

Args:
item: Input dict, could be FunctionCallOutput from openai-agents SDK
with format: {'call_id': str, 'output': str, 'type': 'function_call_output'}

Returns:
ChatCompletionToolMessageParam (TypedDict) with format:
{'role': 'tool', 'content': str, 'tool_call_id': str}
or the original dict if conversion is not needed.
Converts from: {'call_id': ..., 'output': ..., 'type': 'function_call_output'}
To: {'role': 'tool', 'content': ..., 'tool_call_id': ...}
"""
if (
isinstance(item, dict)
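For illustration, the conversion described in the simplified docstring, with made-up values:

# Hypothetical input from the openai-agents SDK (FunctionCallOutput); values are made up.
tool_output = {
    "type": "function_call_output",
    "call_id": "call_abc123",
    "output": "total 42",
}
# _convert_tool_output_format would map it to a standard chat-template tool message:
# {"role": "tool", "content": "total 42", "tool_call_id": "call_abc123"}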
Binary file added assets/qwen3_8b_terminal.png
Empty file added examples/__init__.py
202 changes: 202 additions & 0 deletions examples/openai-agents/agent_terminal_workflow.py
@@ -0,0 +1,202 @@
import asyncio
import logging
import os

from agents import Agent as OpenAIAgent
from agents import ModelSettings, OpenAIProvider, RunConfig, SQLiteSession
from agents import Runner as OpenAIRunner
from terminal.env import TerminalEnv
from terminal.judge_agent import JudgeAgent, judge_from_env
from terminal.prompt import SYSTEM_PROMPT
from transformers import PreTrainedTokenizerFast

from areal.api.cli_args import GenerationHyperparameters
from areal.api.workflow_api import RolloutWorkflow
from areal.experimental.openai import ArealOpenAI
from areal.utils import stats_tracker

logger = logging.getLogger(__name__)


class TerminalAgent:
def __init__(
self,
tokenizer: PreTrainedTokenizerFast,
max_tokens_per_turn: int = 1024,
max_turns: int = 8,
max_total_tokens: int = 32768,
dump_dir: str | None = None,
rollout_stat_scope: str = "rollout",
):
self.tokenizer = tokenizer
self.max_tokens_per_turn = max_tokens_per_turn
self.max_turns = max_turns
self.max_total_tokens = max_total_tokens
self.dump_dir = dump_dir
self.rollout_stat_scope = rollout_stat_scope

async def run_agent(self, data, client: ArealOpenAI, judge_agent: JudgeAgent):
"""Run the agent workflow for terminal task execution."""
run_config = RunConfig(
model_provider=OpenAIProvider(
openai_client=client,
use_responses=True,
),
tracing_disabled=True,
model_settings=ModelSettings(
temperature=1.0,
extra_args={"max_completion_tokens": self.max_tokens_per_turn},
tool_choice="auto",
store=True,
),
)

async with TerminalEnv(
task_name=data["task_name"],
dump_dir=self.dump_dir,
rollout_stat_scope=self.rollout_stat_scope,
) as env:
# Create agent workflow with terminal tools
agent = OpenAIAgent(
name="Terminal Task Agent",
instructions=SYSTEM_PROMPT,
tools=env.get_tools(),
)
session = SQLiteSession("terminal")
content = data["instruction"]

max_attempts = self.max_turns
reward = 0
judge_reward = 0
tracker = stats_tracker.get(self.rollout_stat_scope)

with tracker.record_timing("run_agent_total"):
error_count = 0.0
attempts_used = 0.0
for attempt in range(max_attempts):
attempts_used = float(attempt + 1)
try:
with tracker.record_timing("openai_runner_run"):
result = await OpenAIRunner.run(
agent,
input=content,
session=session,
run_config=run_config,
max_turns=30,
)
except Exception as e:
logger.error(f"Error running agent: {e}")
error_count += 1.0
break

with tracker.record_timing("env_validate_reward"):
reward = env.reward()
if judge_agent:
with tracker.record_timing("judge_agent_reward"):
judge_reward = await judge_agent.get_reward_from_judge(
session=session,
dockerfile_contents=data["dockerfile_contents"],
)
if judge_reward >= 0 and reward < 0.99:
reward = reward * 0.65 + judge_reward * 0.35

tracker.scalar(
reward=reward,
judge_reward=judge_reward,
attempt_index=float(attempt),
input_chars=float(len(content) if content else 0.0),
output_chars=float(
len(getattr(result, "final_output", "") or "")
),
)

if isinstance(reward, float) and reward >= 0.99:
tracker.scalar(success=1.0)
break

if attempt < max_attempts - 1:
content = f"""The previous attempt didn't complete the task successfully.
Please try a different approach.
Original task: {data["instruction"]}

Previous attempt result: {result.final_output}

Please analyze what went wrong and try again with a corrected approach."""
else:
content = f"""This is your final attempt. Please be extremely careful.
Original task: {data["instruction"]}

Previous attempts: {result.final_output}

Please provide a final, carefully executed solution."""
tracker.scalar(success=0.0)

tracker.scalar(
final_reward=reward, attempts_used=attempts_used, errors=error_count
)

client.set_final_reward(reward)

return reward


class TerminalAgentWorkflow(RolloutWorkflow):
def __init__(
self,
gconfig: GenerationHyperparameters,
tokenizer: PreTrainedTokenizerFast,
dump_dir: str | None = None,
rollout_stat_scope: str = "rollout",
n_trajs: int = 1,
max_tokens: int = 32768,
max_turns: int = 8,
):
self.gconfig = gconfig
self.gconfig.n_samples = 1
self.tokenizer = tokenizer
self.dump_dir = dump_dir
self.max_tokens = max_tokens
self.rollout_stat_scope = rollout_stat_scope
if self.dump_dir is not None and not os.path.exists(self.dump_dir):
os.makedirs(self.dump_dir, exist_ok=True)

# Search hyper-parameters
self.n_trajs = n_trajs
self.agent = TerminalAgent(
tokenizer=self.tokenizer,
max_tokens_per_turn=self.gconfig.max_new_tokens,
max_turns=max_turns,
max_total_tokens=max_tokens,
dump_dir=self.dump_dir,
rollout_stat_scope=self.rollout_stat_scope,
)
self.judge_agent = judge_from_env()

async def arun_episode(self, engine, data):
clients = [
ArealOpenAI(
engine=engine, tokenizer=self.tokenizer, tool_call_parser="qwen25"
)
for _ in range(self.n_trajs)
]

# Collect trajectories
rewards = await asyncio.gather(
*[
self.agent.run_agent(
data=data,
client=clients[i],
judge_agent=self.judge_agent,
)
for i in range(self.n_trajs)
]
)
for reward in rewards:
stats_tracker.get(self.rollout_stat_scope).scalar(reward=reward)

interactions_with_reward = {}
for client in clients:
client.apply_reward_discount(turn_discount=0.9)
interactions = client.export_interactions(style="individual")
interactions_with_reward.update(interactions)
return interactions_with_reward
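A minimal sketch of constructing the workflow; the tokenizer, generation config, and engine are placeholders whose construction is assumed, not shown in this diff:

# Sketch only: tokenizer/gconfig/engine setup is assumed.
from transformers import AutoTokenizer

from areal.api.cli_args import GenerationHyperparameters

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
gconfig = GenerationHyperparameters(max_new_tokens=1024)  # field name taken from the diff; other settings assumed

workflow = TerminalAgentWorkflow(
    gconfig=gconfig,
    tokenizer=tokenizer,
    dump_dir="/tmp/terminal_rollouts",
    n_trajs=2,
    max_tokens=32768,
    max_turns=8,
)
# The rollout engine would then drive episodes via:
#   interactions = await workflow.arun_episode(engine, data)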
2 changes: 1 addition & 1 deletion examples/openai-agents/config.yaml
@@ -19,7 +19,7 @@ cluster:
type: nfs
nfs_record_root: /tmp/areal/name_resolve

allocation_mode: sglang.d4p1t1+d4p1t1
allocation_mode: sglang.d4p1t1+d1p1t1c4

rollout:
experiment_name: ${experiment_name}