Skip to content

Commit d3054cd

Browse files
Add codegen step-wise recoder agent
1 parent 380c69f commit d3054cd

File tree

1 file changed

+192
-0
lines changed

1 file changed

+192
-0
lines changed
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""Simple Codegen Agent
2+
3+
Captures human interactions using playwright inspector.
4+
Playwright trace logs are stored in "think" messages and can be viewed in Agentlab Xray.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import json
10+
import logging
11+
import tempfile
12+
import zipfile
13+
from dataclasses import dataclass
14+
from pathlib import Path
15+
16+
import bgym
17+
from playwright.sync_api import Page
18+
19+
from agentlab.agents.agent_args import AgentArgs
20+
from browsergym.core.observation import (
21+
extract_dom_extra_properties,
22+
extract_dom_snapshot,
23+
extract_focused_element_bid,
24+
extract_merged_axtree,
25+
extract_screenshot,
26+
)
27+
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html
28+
29+
30+
def extract_log_message_from_pw_trace(pw_trace_file_path):
31+
zip_file = zipfile.ZipFile(pw_trace_file_path, "r")
32+
trace_lines = zip_file.read("trace.trace").decode("utf-8").splitlines()
33+
34+
actions = []
35+
for line in trace_lines:
36+
if line.strip():
37+
event = json.loads(line)
38+
if event.get("type") == "log":
39+
actions.append(event)
40+
# Extract log messages from the trace
41+
return [log["message"].strip() for log in sorted(actions, key=lambda x: x.get("time", 0))]
42+
43+
44+
def clean_pw_logs(logs, exclude_blacklist=True, use_substitutions=True):
45+
clean_logs = list(logs)
46+
blacklist = {
47+
"attempting click action",
48+
"waiting for element to be visible, enabled and stable",
49+
"element is visible, enabled and stable",
50+
"scrolling into view if needed",
51+
"done scrolling",
52+
"performing click action",
53+
"click action done",
54+
"waiting for scheduled navigations to finish",
55+
"navigations have finished",
56+
}
57+
58+
substitutions = [("waiting for ", "")]
59+
60+
def apply_substitutions(log):
61+
for old, new in substitutions:
62+
log = log.replace(old, new)
63+
return log
64+
65+
if exclude_blacklist:
66+
clean_logs = [log for log in clean_logs if log not in blacklist]
67+
if use_substitutions:
68+
clean_logs = [apply_substitutions(log) for log in clean_logs]
69+
70+
return clean_logs
71+
72+
73+
@dataclass
74+
class PlayWrightCodeGenAgentArgs(AgentArgs):
75+
agent_name: str = "PlayWrightCodeGenAgent"
76+
trace_dir: str = "playwright_codegen_traces"
77+
use_raw_page_output: bool = True
78+
store_raw_trace: bool = False
79+
80+
def make_agent(self) -> bgym.Agent: # type: ignore[override]
81+
return PlayWrightCodeGenAgent(self.trace_dir, self.store_raw_trace)
82+
83+
def set_reproducibility_mode(self):
84+
pass
85+
86+
87+
class PlayWrightCodeGenAgent(bgym.Agent):
88+
def __init__(self, trace_dir: str, store_raw_trace: bool):
89+
self.action_set = bgym.HighLevelActionSet(["bid"], multiaction=False)
90+
self._root = Path(trace_dir)
91+
self._page: Page | None = None
92+
self._step = 0
93+
self.store_raw_trace = store_raw_trace
94+
self._episode_trace_dir = None # Cache for single episode
95+
96+
def _get_trace_dir(self):
97+
"""Return the trace directory based on store_raw_trace setting."""
98+
if self._episode_trace_dir is None:
99+
if self.store_raw_trace:
100+
import datetime
101+
102+
dt_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
103+
self._episode_trace_dir = self._root / f"codegen_traces_{dt_str}"
104+
self._episode_trace_dir.mkdir(parents=True, exist_ok=True)
105+
else:
106+
self._episode_trace_dir = Path(tempfile.mkdtemp())
107+
return self._episode_trace_dir
108+
109+
def obs_preprocessor(self, obs: dict): # type: ignore[override]
110+
if isinstance(obs, dict):
111+
self._page = obs.get("page")
112+
obs["screenshot"] = extract_screenshot(self._page)
113+
obs["dom_object"] = extract_dom_snapshot(self._page)
114+
obs["axtree_object"] = extract_merged_axtree(self._page)
115+
scale_factor = getattr(self._page, "_bgym_scale_factor", 1.0)
116+
extra_properties = extract_dom_extra_properties(
117+
obs["dom_object"], scale_factor=scale_factor
118+
)
119+
obs["extra_element_properties"] = extra_properties
120+
obs["focused_element_bid"] = extract_focused_element_bid(self._page)
121+
122+
if obs["axtree_object"]:
123+
obs["axtree_txt"] = flatten_axtree_to_str(obs["axtree_object"])
124+
125+
if obs["dom_object"]:
126+
obs["dom_txt"] = flatten_dom_to_str(obs["dom_object"])
127+
obs["pruned_html"] = prune_html(obs["dom_txt"])
128+
129+
if "page" in obs: # unpickable
130+
del obs["page"]
131+
132+
return obs
133+
134+
def get_action(self, obs: dict): # type: ignore[override]
135+
136+
if self._page is None:
137+
raise RuntimeError("Playwright Page missing; ensure use_raw_page_output=True")
138+
139+
page = self._page
140+
trace_dir = self._get_trace_dir()
141+
trace_path = trace_dir / f"step_{self._step}.zip"
142+
page.context.tracing.start(screenshots=True, snapshots=True, sources=True)
143+
page.context.tracing.start_chunk(name=f"step_{self._step}")
144+
145+
print(
146+
f"{'─'*60}\n" f"Step {self._step}\n",
147+
f"{'─'*60}\n",
148+
"1. 🔴 Start Recording (Press 'Record' in the Playwright Inspector.)\n",
149+
"2. ✨ Perform actions for a single step.\n",
150+
"3. ⚫ Stop Recording (Press 'Record' again to stop recording.)\n",
151+
"4. ▶️ Press 'Resume' in the Playwright Inspector.",
152+
)
153+
154+
page.pause() # Launch Inspector and record actions
155+
page.context.tracing.stop_chunk(path=trace_path)
156+
page.context.tracing.stop()
157+
158+
pw_logs = extract_log_message_from_pw_trace(trace_path)
159+
pw_logs = clean_pw_logs(pw_logs, exclude_blacklist=True)
160+
pw_logs_str = "\n".join([f"{i}. {log}" for i, log in enumerate(pw_logs, 1)])
161+
162+
print(f"\n Playwright logs for step {self._step}:\n{pw_logs_str}")
163+
164+
self._step += 1
165+
166+
agent_info = bgym.AgentInfo(
167+
think=pw_logs_str,
168+
chat_messages=[],
169+
stats={},
170+
)
171+
172+
return "noop()", agent_info
173+
174+
175+
PW_CODEGEN_AGENT = PlayWrightCodeGenAgentArgs(store_raw_trace=True)
176+
177+
178+
if __name__ == "__main__":
179+
from agentlab.agents.human_trace_recorder.codegen_agent import PW_CODEGEN_AGENT
180+
from agentlab.experiments.study import Study
181+
182+
agent_configs = [PW_CODEGEN_AGENT]
183+
benchmark = bgym.DEFAULT_BENCHMARKS["workarena_l1"]() # type: bgym.Benchmark
184+
benchmark = benchmark.subset_from_glob("task_name", "*create*")
185+
benchmark.env_args_list = benchmark.env_args_list[:1]
186+
for env_args in benchmark.env_args_list:
187+
print(env_args.task_name)
188+
env_args.max_steps = 15
189+
env_args.headless = False
190+
191+
study = Study(agent_configs, benchmark, logging_level_stdout=logging.INFO)
192+
study.run(n_jobs=1, parallel_backend="sequential", n_relaunch=1)

0 commit comments

Comments
 (0)