Skip to content

Commit 4052cb0

Browse files
add updating action and thought, add reload task instead of reset
1 parent 444d2d0 commit 4052cb0

File tree

2 files changed

+166
-28
lines changed

2 files changed

+166
-28
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def set_task_selector():
174174

175175
prepare_agent()
176176
set_environment_info()
177+
prepare_benchmark()
177178
reset_environment()
178179

179180

@@ -216,6 +217,16 @@ def set_environment_info():
216217
logger.info(f"Done in {end - start}")
217218

218219

220+
def prepare_benchmark():
221+
logger.info("Preparing benchmark...")
222+
start = datetime.now()
223+
resp = requests.post(f"{SERVER_URL}/prepare_benchmark")
224+
if resp.status_code != 200 or resp.json().get("status") != "success":
225+
st.error(resp.json())
226+
end = datetime.now()
227+
logger.info(f"Done in {end - start}")
228+
229+
219230
def reset_environment():
220231
logger.info("Restarting environment...")
221232
start = datetime.now()
@@ -242,6 +253,29 @@ def reset_environment():
242253
logger.info(f"Done postproc in {end - start}")
243254

244255

256+
def reload_task():
257+
logger.info("Reloading task...")
258+
start = datetime.now()
259+
resp = requests.post(f"{SERVER_URL}/reload_task")
260+
if resp.status_code != 200 or resp.json().get("status") != "success":
261+
print(resp.status_code)
262+
print(resp.json()["status"])
263+
print(resp.json()["message"])
264+
response_json = resp.json()
265+
if "obs" in response_json:
266+
if "screenshot" in response_json["obs"]:
267+
screenshot_data = response_json["obs"]["screenshot"]
268+
# convert base64 to numpy array
269+
screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]))
270+
screenshot = screenshot.reshape(screenshot_data["shape"])
271+
response_json["obs"]["screenshot"] = screenshot
272+
if st.session_state.agent.obs_preprocessor:
273+
response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"])
274+
st.session_state.last_obs = response_json["obs"]
275+
end = datetime.now()
276+
logger.info(f"Done in {end - start}")
277+
278+
245279
def step_environment(action):
246280
logger.info("Stepping environment...")
247281
start = datetime.now()
@@ -269,7 +303,7 @@ def step_environment(action):
269303

270304

271305
def restore_environment():
272-
reset_environment()
306+
reload_task()
273307
for action in st.session_state.actions_history:
274308
step_environment(action)
275309

@@ -285,21 +319,46 @@ def get_action():
285319

286320

287321
def set_agent_state_box():
322+
323+
# Custom CSS to set textarea style same as code block
324+
st.markdown(
325+
"""
326+
<style>
327+
@import url('https://fonts.googleapis.com/css2?family=Handlee&family=IBM+Plex+Mono:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;1,100;1,200;1,300;1,400;1,500;1,600;1,700&family=Sedgwick+Ave&display=swap');
328+
textarea, .stTextArea textarea {
329+
font-family: "IBM Plex Mono", monospace !important;
330+
font-size: 14px !important;
331+
font-weight: 400;
332+
font-style: normal;
333+
line-height: 1.6 !important;
334+
padding-top: 18px !important;
335+
background-color: #F8F9FB !important;
336+
337+
}
338+
</style>
339+
""",
340+
unsafe_allow_html=True,
341+
)
342+
288343
# set agent state and goal box
289344
with st.container():
290345
col1, col2, col3 = st.columns([1, 1, 1])
291346
with col1:
292347
with st.container(border=True, height=250):
293348
st.markdown("**Goal**")
349+
# st.text_area("", st.session_state.agent.obs_history[-1]["goal"], height=175, disabled=True, label_visibility="collapsed")
294350
st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175)
295351
with col2:
296352
with st.container(border=True, height=250):
297353
st.markdown("**Think**")
298-
st.code(st.session_state.action_info.think, wrap_lines=True, language=None, height=175)
354+
st.session_state.action_info.think = st.text_area(
355+
"Think", st.session_state.action_info.think, height=172, label_visibility="collapsed"
356+
)
299357
with col3:
300358
with st.container(border=True, height=250):
301359
st.markdown("**Action**")
302-
st.code(st.session_state.action, wrap_lines=True, language="python", height=175)
360+
st.session_state.action = st.text_area("Action", st.session_state.action, height=172, label_visibility="collapsed")
361+
# st.code(st.session_state.action, wrap_lines=True, language="python", height=175)
303362

304363

305364
def set_prompt_modifier():

src/agentlab/analyze/server.py

Lines changed: 104 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def __init__(self):
106106
self.last_obs = None
107107
self.last_info = None
108108

109+
# used to reload task
110+
self.start_info = None
111+
self.start_url = None
112+
109113
def set_info(
110114
self,
111115
benchmark_name: str,
@@ -223,12 +227,7 @@ def status(self) -> dict:
223227
}
224228
)
225229

226-
def reset(self) -> dict:
227-
"""Reset the environment
228-
229-
:return: Dictionary with obs and info
230-
:rtype: dict
231-
"""
230+
def prepare_benchmark(self) -> dict:
232231
start = time.time()
233232
if not self.info_set:
234233
return make_json_safe(
@@ -237,54 +236,124 @@ def reset(self) -> dict:
237236
"message": "Environment info not set. Please set the environment info first.",
238237
}
239238
)
239+
240240
if self.env is not None:
241241
# close the current environment first
242242
self.env.close()
243243
self.env = None
244-
245244
# then create the new environment
246245
benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]()
247246
benchmark.env_args_list = [
248247
elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed)
249248
]
249+
start = time.time()
250250
benchmark.prepare_backends()
251+
end = time.time()
252+
logger.info(f"prepare_backends done in {end - start}")
251253

252254
env_args = benchmark.env_args_list[0]
253-
# env_args.headless = False
254-
255255
self.action_mapping = import_from_path(self.action_mapping_fn)
256-
end = time.time()
257-
logger.info(f"init reset done in {end - start}")
256+
257+
# create environment
258258
start = time.time()
259259
self.env = env_args.make_env(self.action_mapping, self.exp_dir)
260+
print(self.env)
260261
end = time.time()
261262
logger.info(f"make_env done in {end - start}")
263+
return make_json_safe(
264+
{
265+
"status": "success",
266+
"message": "Environment prepared successfully.",
267+
}
268+
)
269+
270+
def reload_task(self) -> dict:
271+
"""Reload the task
272+
273+
:return: Dictionary with status
274+
:rtype: dict
275+
"""
276+
start = time.time()
277+
if not self.info_set:
278+
return make_json_safe(
279+
{
280+
"status": "error",
281+
"message": "Environment info not set. Please set the environment info first.",
282+
}
283+
)
284+
elif not self.env:
285+
return make_json_safe(
286+
{
287+
"status": "error",
288+
"message": "Environment not created. Please create an environment first.",
289+
}
290+
)
291+
292+
tmp_start = time.time()
293+
self.env.unwrapped.page.goto(self.start_url, wait_until="load")
294+
tmp_end = time.time()
295+
logger.info(f"goto done in {tmp_end - tmp_start}")
296+
tmp_start = time.time()
297+
self.env.unwrapped.page.evaluate("window.localStorage.clear(); window.sessionStorage.clear();")
298+
299+
obs = self.env.unwrapped._get_obs()
300+
tmp_end = time.time()
301+
logger.info(f"clear storage done in {tmp_end - tmp_start}")
302+
303+
end = time.time()
304+
logger.info(f"reload_task done in {end - start}")
305+
306+
self.last_obs = copy.deepcopy(obs)
307+
self.last_info = copy.deepcopy(self.start_info)
308+
return make_json_safe(
309+
{
310+
"status": "success",
311+
"message": "Task reloaded successfully.",
312+
"obs": self.last_obs,
313+
"info": self.last_info,
314+
}
315+
)
316+
317+
def reset(self) -> dict:
318+
"""Reset the environment
319+
320+
:return: Dictionary with obs and info
321+
:rtype: dict
322+
"""
262323
start = time.time()
324+
if not self.info_set:
325+
return make_json_safe(
326+
{
327+
"status": "error",
328+
"message": "Environment info not set. Please set the environment info first.",
329+
}
330+
)
331+
elif not self.env:
332+
return make_json_safe(
333+
{
334+
"status": "error",
335+
"message": "Environment not created. Please create an environment first.",
336+
}
337+
)
338+
263339
# finally, reset the environment
340+
start = time.time()
264341
obs, info = self.env.reset(seed=self.seed)
265-
self.last_obs = copy.deepcopy(obs)
266-
self.last_info = copy.deepcopy(info)
267342
end = time.time()
268343
logger.info(f"env reset done in {end - start}")
269-
start = time.time()
270-
# out = make_json_safe(
271-
out = make_json_safe(
344+
345+
self.last_obs = copy.deepcopy(obs)
346+
self.last_info = copy.deepcopy(info)
347+
self.start_info = copy.deepcopy(info)
348+
self.start_url = copy.deepcopy(self.env.unwrapped.page.url)
349+
return make_json_safe(
272350
{
273351
"status": "success",
274352
"message": "Environment reset successfully",
275353
"obs": self.last_obs,
276354
"info": self.last_info,
277355
}
278356
)
279-
end = time.time()
280-
logger.info(f"payload cleaned in {end - start}")
281-
# log payload size
282-
from pympler import asizeof
283-
284-
logger.info(f"Payload size: {asizeof.asizeof(out)} bytes")
285-
# print(out)
286-
# return {"status": "success", "message": "Environment reset successfully"}
287-
return out
288357

289358
def step(self, action: str) -> dict:
290359
"""Step the environment
@@ -398,6 +467,16 @@ def status():
398467
return env.status()
399468

400469

470+
@app.post("/prepare_benchmark")
471+
def prepare_benchmark():
472+
return env.prepare_benchmark()
473+
474+
475+
@app.post("/reload_task")
476+
def reload_task():
477+
return env.reload_task()
478+
479+
401480
@app.post("/reset")
402481
def reset():
403482
return env.reset()

0 commit comments

Comments
 (0)