Skip to content

Commit 932aafc

Browse files
clean up controller and server
1 parent fed2448 commit 932aafc

File tree

2 files changed

+23
-49
lines changed

2 files changed

+23
-49
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import importlib
44
import logging
55
from io import BytesIO
6-
from pathlib import Path
76
import requests
87
import numpy as np
98
import PIL.Image

src/agentlab/analyze/server.py

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,27 @@
22
import base64
33
import copy
44
import importlib
5-
import logging
6-
import time
7-
from typing import Any, Dict, Optional
85

96
import dotenv
107
import numpy as np
118
import uvicorn
129

1310
# Import your BrowserEnv and any task setup you need
1411
from bgym import DEFAULT_BENCHMARKS
15-
from browsergym.core.env import BrowserEnv
16-
from browsergym.core.task import AbstractBrowserTask
17-
from fastapi import FastAPI, Request
12+
from fastapi import FastAPI
1813
from pydantic import BaseModel
1914

2015
dotenv.load_dotenv()
2116

22-
logger = logging.getLogger(__name__)
23-
logger.setLevel(logging.INFO)
24-
2517
app = FastAPI()
2618

2719

28-
# Utils to import the action mapping fn
2920
def import_from_path(path):
3021
"""
31-
Import and instantiate a class, then return its 'to_python_code' method.
32-
For example, given 'browsergym.core.action.highlevel.HighLevelActionSet.to_python_code',
33-
this will instantiate HighLevelActionSet and return its to_python_code method.
22+
Util function to import and instantiate a class, then return a specific method.
23+
For example, given `browsergym.core.action.highlevel.HighLevelActionSet.to_python_code`,
24+
this will instantiate `HighLevelActionSet` and return its `to_python_code` method.
3425
"""
35-
import importlib
3626

3727
parts = path.split(".")
3828
# Find the module (the longest prefix that can be imported)
@@ -61,8 +51,11 @@ def import_from_path(path):
6151
return obj
6252

6353

64-
## Utils to convert to safe JSON response
6554
def make_json_safe(obj):
55+
"""
56+
Util function to convert numpy arrays and other non-JSON-serializable objects to JSON-serializable objects.
57+
Specifically, we convert numpy arrays to base64 encoded strings so that payloads are of reasonable size.
58+
"""
6659
if isinstance(obj, np.ndarray):
6760
# convert to base64
6861
return {"data": base64.b64encode(obj.tobytes()).decode("utf-8"), "shape": obj.shape, "dtype": str(obj.dtype)}
@@ -228,7 +221,12 @@ def status(self) -> dict:
228221
)
229222

230223
def prepare_benchmark(self) -> dict:
231-
start = time.time()
224+
"""
225+
Prepare the benchmark environment.
226+
227+
:return: Dictionary with status
228+
:rtype: dict
229+
"""
232230
if not self.info_set:
233231
return make_json_safe(
234232
{
@@ -241,25 +239,19 @@ def prepare_benchmark(self) -> dict:
241239
# close the current environment first
242240
self.env.close()
243241
self.env = None
244-
# then create the new environment
242+
243+
# prepare backends
245244
benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]()
246245
benchmark.env_args_list = [
247246
elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed)
248247
]
249-
start = time.time()
250248
benchmark.prepare_backends()
251-
end = time.time()
252-
logger.info(f"prepare_backends done in {end - start}")
253249

254250
env_args = benchmark.env_args_list[0]
255251
self.action_mapping = import_from_path(self.action_mapping_fn)
256252

257253
# create environment
258-
start = time.time()
259254
self.env = env_args.make_env(self.action_mapping, self.exp_dir)
260-
print(self.env)
261-
end = time.time()
262-
logger.info(f"make_env done in {end - start}")
263255
return make_json_safe(
264256
{
265257
"status": "success",
@@ -273,7 +265,6 @@ def reload_task(self) -> dict:
273265
:return: Dictionary with status
274266
:rtype: dict
275267
"""
276-
start = time.time()
277268
if not self.info_set:
278269
return make_json_safe(
279270
{
@@ -289,19 +280,12 @@ def reload_task(self) -> dict:
289280
}
290281
)
291282

292-
tmp_start = time.time()
283+
# instead of resetting the whole environment, we go back to the original webpage and clear localStorage and sessionStorage
284+
# NOTE: this is not guaranteed to result in the exact same state, but we find that it works most of the time, is much
285+
# faster than resetting the whole environment, and ensures the seed of the environment remains the same
293286
self.env.unwrapped.page.goto(self.start_url, wait_until="load")
294-
tmp_end = time.time()
295-
logger.info(f"goto done in {tmp_end - tmp_start}")
296-
tmp_start = time.time()
297287
self.env.unwrapped.page.evaluate("window.localStorage.clear(); window.sessionStorage.clear();")
298-
299288
obs = self.env.unwrapped._get_obs()
300-
tmp_end = time.time()
301-
logger.info(f"clear storage done in {tmp_end - tmp_start}")
302-
303-
end = time.time()
304-
logger.info(f"reload_task done in {end - start}")
305289

306290
self.last_obs = copy.deepcopy(obs)
307291
self.last_info = copy.deepcopy(self.start_info)
@@ -320,7 +304,6 @@ def reset(self) -> dict:
320304
:return: Dictionary with obs and info
321305
:rtype: dict
322306
"""
323-
start = time.time()
324307
if not self.info_set:
325308
return make_json_safe(
326309
{
@@ -336,11 +319,8 @@ def reset(self) -> dict:
336319
}
337320
)
338321

339-
# finally, reset the environment
340-
start = time.time()
322+
# reset the environment
341323
obs, info = self.env.reset(seed=self.seed)
342-
end = time.time()
343-
logger.info(f"env reset done in {end - start}")
344324

345325
self.last_obs = copy.deepcopy(obs)
346326
self.last_info = copy.deepcopy(info)
@@ -370,14 +350,12 @@ def step(self, action: str) -> dict:
370350
"message": "Environment not created. Please create an environment first.",
371351
}
372352
)
373-
start = time.time()
353+
# step the environment
374354
obs, reward, terminated, truncated, info = self.env.step(action)
375-
end = time.time()
376-
logger.info(f"env step done in {end - start}")
377-
start = time.time()
355+
378356
self.last_obs = copy.deepcopy(obs)
379357
self.last_info = copy.deepcopy(info)
380-
out = make_json_safe(
358+
return make_json_safe(
381359
{
382360
"status": "success",
383361
"message": "Environment stepped successfully.",
@@ -388,9 +366,6 @@ def step(self, action: str) -> dict:
388366
"info": info,
389367
}
390368
)
391-
end = time.time()
392-
logger.info(f"obs copied in {end - start}")
393-
return out
394369

395370
def get_obs(self) -> dict:
396371
"""Get the last observation

0 commit comments

Comments
 (0)