Skip to content

Commit b7c799e

Browse files
clean up agent controller
1 parent 932aafc commit b7c799e

File tree

1 file changed

+24
-33
lines changed

1 file changed

+24
-33
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def get_import_path(obj):
4141
return f"{obj.__module__}.{obj.__qualname__}"
4242

4343

44+
def deserialize_response(response_json):
45+
if "obs" in response_json:
46+
if "screenshot" in response_json["obs"]:
47+
screenshot_data = response_json["obs"]["screenshot"]
48+
# convert base64 to numpy array
49+
screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]))
50+
screenshot = screenshot.reshape(screenshot_data["shape"])
51+
response_json["obs"]["screenshot"] = screenshot
52+
return response_json
53+
54+
4455
def setup_sidebar():
4556
with st.sidebar:
4657
st.markdown(
@@ -234,17 +245,11 @@ def reset_environment():
234245
logger.info(f"Done request in {end - start}")
235246
start = datetime.now()
236247
if resp.status_code != 200 or resp.json().get("status") != "success":
237-
print(resp.status_code)
238-
print(resp.json()["status"])
239-
print(resp.json()["message"])
248+
logger.error(resp.status_code)
249+
logger.error(resp.json()["status"])
250+
logger.error(resp.json()["message"])
240251
response_json = resp.json()
241-
if "obs" in response_json:
242-
if "screenshot" in response_json["obs"]:
243-
screenshot_data = response_json["obs"]["screenshot"]
244-
# convert base64 to numpy array
245-
screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]))
246-
screenshot = screenshot.reshape(screenshot_data["shape"])
247-
response_json["obs"]["screenshot"] = screenshot
252+
response_json = deserialize_response(response_json)
248253
if st.session_state.agent.obs_preprocessor:
249254
response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"])
250255
st.session_state.last_obs = response_json["obs"]
@@ -257,17 +262,11 @@ def reload_task():
257262
start = datetime.now()
258263
resp = requests.post(f"{SERVER_URL}/reload_task")
259264
if resp.status_code != 200 or resp.json().get("status") != "success":
260-
print(resp.status_code)
261-
print(resp.json()["status"])
262-
print(resp.json()["message"])
265+
logger.error(resp.status_code)
266+
logger.error(resp.json()["status"])
267+
logger.error(resp.json()["message"])
263268
response_json = resp.json()
264-
if "obs" in response_json:
265-
if "screenshot" in response_json["obs"]:
266-
screenshot_data = response_json["obs"]["screenshot"]
267-
# convert base64 to numpy array
268-
screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]))
269-
screenshot = screenshot.reshape(screenshot_data["shape"])
270-
response_json["obs"]["screenshot"] = screenshot
269+
response_json = deserialize_response(response_json)
271270
if st.session_state.agent.obs_preprocessor:
272271
response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"])
273272
st.session_state.last_obs = response_json["obs"]
@@ -281,17 +280,12 @@ def step_environment(action):
281280
payload = {"action": action}
282281
resp = requests.post(f"{SERVER_URL}/step", json=payload)
283282
if resp.status_code != 200 or resp.json().get("status") != "success":
284-
print(resp.status_code)
285-
print(resp.json()["status"])
286-
print(resp.json()["message"])
283+
logger.error(resp.status_code)
284+
logger.error(resp.json()["status"])
285+
logger.error(resp.json()["message"])
287286
response_json = resp.json()
288-
if "obs" in response_json:
289-
if "screenshot" in response_json["obs"]:
290-
screenshot_data = response_json["obs"]["screenshot"]
291-
# convert base64 to numpy array
292-
screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]))
293-
screenshot = screenshot.reshape(screenshot_data["shape"])
294-
response_json["obs"]["screenshot"] = screenshot
287+
response_json = deserialize_response(response_json)
288+
295289
if st.session_state.agent.obs_preprocessor:
296290
response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"])
297291
st.session_state.last_obs = response_json["obs"]
@@ -345,7 +339,6 @@ def set_agent_state_box():
345339
with col1:
346340
with st.container(border=True, height=250):
347341
st.markdown("**Goal**")
348-
# st.text_area("", st.session_state.agent.obs_history[-1]["goal"], height=175, disabled=True, label_visibility="collapsed")
349342
st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175)
350343
with col2:
351344
with st.container(border=True, height=250):
@@ -357,12 +350,10 @@ def set_agent_state_box():
357350
with st.container(border=True, height=250):
358351
st.markdown("**Action**")
359352
st.session_state.action = st.text_area("Action", st.session_state.action, height=172, label_visibility="collapsed")
360-
# st.code(st.session_state.action, wrap_lines=True, language="python", height=175)
361353

362354

363355
def set_prompt_modifier():
364356
with st.expander("**Prompt Modifier**", expanded=False):
365-
# st.write(st.session_state.agent.flags)
366357
st.markdown("**Observation Flags**")
367358
col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1])
368359
with col1:

0 commit comments

Comments
 (0)