Skip to content

Commit 5d693e1

Browse files
committed
perf: update openmanus rollout
1 parent 89cbaa9 commit 5d693e1

File tree

1 file changed

+36
-11
lines changed

1 file changed

+36
-11
lines changed

scripts/rollout/openmanus_rollout.py

Lines changed: 36 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,10 @@
2929
pass
3030

3131

32-
class UnifiedAgent:
33-
"""Unified agent that can work with all environments"""
32+
class OpenManusAgent:
33+
"""OpenManus agent that can work with all environments"""
3434

35-
def __init__(self, model_name: str = "gpt-4o", temperature: float = 0.4,
35+
def __init__(self, model_name: str = "gpt-4o", temperature: float = 0.7,
3636
base_url: str | None = None, env_type: str = "alfworld"):
3737
self.model_name = model_name
3838
self.temperature = temperature
@@ -284,14 +284,14 @@ def main():
284284
parser.add_argument("--test_times", type=int, default=1,
285285
help="Number of test runs per batch")
286286
parser.add_argument("--max_steps", type=int, default=None,
287-
help="Maximum steps per episode (default: 50 for alfworld, 30 for gaia/webshop)")
287+
help="Maximum steps per episode (default: 30)")
288288
parser.add_argument("--seed", type=int, default=1)
289-
parser.add_argument("--history_length", type=int, default=2)
289+
parser.add_argument("--history_length", type=int, default=3)
290290

291291
# Model parameters
292-
parser.add_argument("--model", default="gpt-4o-mini",
292+
parser.add_argument("--model", default="gpt-4o",
293293
help="Model name (OpenAI: gpt-4o, gpt-4o-mini; Together: Qwen/Qwen2.5-7B-Instruct-Turbo, etc.)")
294-
parser.add_argument("--temperature", type=float, default=0.4)
294+
parser.add_argument("--temperature", type=float, default=0.7)
295295
parser.add_argument("--base_url", default=None,
296296
help="OpenAI-compatible base URL (e.g., vLLM http://127.0.0.1:8000/v1)")
297297

@@ -304,8 +304,14 @@ def main():
304304
# Output parameters
305305
parser.add_argument("--dump_path", default=None,
306306
help="If set, write JSONL trajectory to this file")
307-
parser.add_argument("--chat_root", default=None,
308-
help="If set, save per-episode chat histories under this root")
307+
parser.add_argument(
308+
"--chat_root",
309+
default=os.getcwd(),
310+
help=(
311+
"Root directory to save per-episode chat histories. "
312+
"Defaults to the current working directory."
313+
),
314+
)
309315

310316
# Environment-specific parameters
311317
parser.add_argument("--alf_env_type", default="alfworld/AlfredTWEnv",
@@ -339,7 +345,7 @@ def main():
339345
# Set default max_steps based on environment
340346
if args.max_steps is None:
341347
args.max_steps = {
342-
"alfworld": 50,
348+
"alfworld": 30,
343349
"gaia": 30,
344350
"webshop": 30
345351
}[args.env]
@@ -431,7 +437,7 @@ def _sanitize(s: str) -> str:
431437
sys.exit(0)
432438

433439
# Initialize agent (defer until after potential dry-run exit to avoid requiring API keys)
434-
agent = UnifiedAgent(
440+
agent = OpenManusAgent(
435441
model_name=args.model,
436442
temperature=args.temperature,
437443
base_url=args.base_url,
@@ -617,6 +623,12 @@ def _sanitize(s: str) -> str:
617623
"environment": args.env,
618624
}
619625

626+
# Add environment-specific task identifiers
627+
if args.env == "alfworld":
628+
meta["gamefile"] = infos[i].get("extra.gamefile", "")
629+
elif args.env == "gaia":
630+
meta["pid"] = infos[i].get("pid", "unknown")
631+
620632
with open(out_path, "w", encoding="utf-8") as f:
621633
json.dump({"messages": chats[i], "metadata": meta}, f, ensure_ascii=False, indent=2)
622634
saved_flags[i] = True
@@ -647,6 +659,13 @@ def _sanitize(s: str) -> str:
647659
"environment": args.env,
648660
}
649661

662+
# Add environment-specific task identifiers for unfinished tasks
663+
if last_infos and i < len(last_infos):
664+
if args.env == "alfworld":
665+
meta["gamefile"] = last_infos[i].get("extra.gamefile", "")
666+
elif args.env == "gaia":
667+
meta["pid"] = last_infos[i].get("pid", "unknown")
668+
650669
with open(out_path, "w", encoding="utf-8") as f:
651670
json.dump({"messages": chats[i], "metadata": meta}, f, ensure_ascii=False, indent=2)
652671
saved_flags[i] = True
@@ -700,6 +719,12 @@ def _sanitize(s: str) -> str:
700719
logging.info(f"Environment: {args.env}")
701720
logging.info(f"Total batches: {num_batches} | Batch size: {args.batch_size} | Total envs processed: {global_env_counter}")
702721

722+
# Echo save locations to make it easy to find outputs.
723+
if args.dump_path:
724+
logging.info(f"Trajectory file: {args.dump_path}")
725+
if chat_base_dir:
726+
logging.info(f"Chats directory: {chat_base_dir}")
727+
703728
if all_overall_success_rates:
704729
logging.info(
705730
f"Overall success avg ± std: "

0 commit comments

Comments
 (0)