22import base64
33import copy
44import importlib
5- import logging
6- import time
7- from typing import Any , Dict , Optional
85
96import dotenv
107import numpy as np
118import uvicorn
129
1310# Import your BrowserEnv and any task setup you need
1411from bgym import DEFAULT_BENCHMARKS
15- from browsergym .core .env import BrowserEnv
16- from browsergym .core .task import AbstractBrowserTask
17- from fastapi import FastAPI , Request
12+ from fastapi import FastAPI
1813from pydantic import BaseModel
1914
2015dotenv .load_dotenv ()
2116
22- logger = logging .getLogger (__name__ )
23- logger .setLevel (logging .INFO )
24-
2517app = FastAPI ()
2618
2719
28- # Utils to import the action mapping fn
2920def import_from_path (path ):
3021 """
31- Import and instantiate a class, then return its 'to_python_code' method.
32- For example, given ' browsergym.core.action.highlevel.HighLevelActionSet.to_python_code' ,
33- this will instantiate HighLevelActionSet and return its to_python_code method.
22+ Util function to import and instantiate a class, then return a specific method.
23+ For example, given ` browsergym.core.action.highlevel.HighLevelActionSet.to_python_code` ,
24+ this will instantiate ` HighLevelActionSet` and return its ` to_python_code` method.
3425 """
35- import importlib
3626
3727 parts = path .split ("." )
3828 # Find the module (the longest prefix that can be imported)
@@ -61,8 +51,11 @@ def import_from_path(path):
6151 return obj
6252
6353
64- ## Utils to convert to safe JSON response
6554def make_json_safe (obj ):
55+ """
56+ Util function to convert numpy arrays and other non-JSON-serializable objects to JSON-serializable objects.
57+ Specifically, we convert numpy arrays to base64 encoded strings so that payloads are of reasonable size.
58+ """
6659 if isinstance (obj , np .ndarray ):
6760 # convert to base64
6861 return {"data" : base64 .b64encode (obj .tobytes ()).decode ("utf-8" ), "shape" : obj .shape , "dtype" : str (obj .dtype )}
@@ -228,7 +221,12 @@ def status(self) -> dict:
228221 )
229222
230223 def prepare_benchmark (self ) -> dict :
231- start = time .time ()
224+ """
225+ Prepare the benchmark environment.
226+
227+ :return: Dictionary with status
228+ :rtype: dict
229+ """
232230 if not self .info_set :
233231 return make_json_safe (
234232 {
@@ -241,25 +239,19 @@ def prepare_benchmark(self) -> dict:
241239 # close the current environment first
242240 self .env .close ()
243241 self .env = None
244- # then create the new environment
242+
243+ # prepare backends
245244 benchmark = DEFAULT_BENCHMARKS [self .benchmark_name ]()
246245 benchmark .env_args_list = [
247246 elem for elem in benchmark .env_args_list if elem .task_name == self .task_name and str (elem .task_seed ) == str (self .seed )
248247 ]
249- start = time .time ()
250248 benchmark .prepare_backends ()
251- end = time .time ()
252- logger .info (f"prepare_backends done in { end - start } " )
253249
254250 env_args = benchmark .env_args_list [0 ]
255251 self .action_mapping = import_from_path (self .action_mapping_fn )
256252
257253 # create environment
258- start = time .time ()
259254 self .env = env_args .make_env (self .action_mapping , self .exp_dir )
260- print (self .env )
261- end = time .time ()
262- logger .info (f"make_env done in { end - start } " )
263255 return make_json_safe (
264256 {
265257 "status" : "success" ,
@@ -273,7 +265,6 @@ def reload_task(self) -> dict:
273265 :return: Dictionary with status
274266 :rtype: dict
275267 """
276- start = time .time ()
277268 if not self .info_set :
278269 return make_json_safe (
279270 {
@@ -289,19 +280,12 @@ def reload_task(self) -> dict:
289280 }
290281 )
291282
292- tmp_start = time .time ()
283+ # instead of resetting the whole environment, we go back to the original webpage and clear localStorage and sessionStorage
284+ # NOTE: this is not guaranteed to result in the exact same state, but we find that it works most of the time, is much
285+ # faster than resetting the whole environment, and ensures the seed of the environment remains the same
293286 self .env .unwrapped .page .goto (self .start_url , wait_until = "load" )
294- tmp_end = time .time ()
295- logger .info (f"goto done in { tmp_end - tmp_start } " )
296- tmp_start = time .time ()
297287 self .env .unwrapped .page .evaluate ("window.localStorage.clear(); window.sessionStorage.clear();" )
298-
299288 obs = self .env .unwrapped ._get_obs ()
300- tmp_end = time .time ()
301- logger .info (f"clear storage done in { tmp_end - tmp_start } " )
302-
303- end = time .time ()
304- logger .info (f"reload_task done in { end - start } " )
305289
306290 self .last_obs = copy .deepcopy (obs )
307291 self .last_info = copy .deepcopy (self .start_info )
@@ -320,7 +304,6 @@ def reset(self) -> dict:
320304 :return: Dictionary with obs and info
321305 :rtype: dict
322306 """
323- start = time .time ()
324307 if not self .info_set :
325308 return make_json_safe (
326309 {
@@ -336,11 +319,8 @@ def reset(self) -> dict:
336319 }
337320 )
338321
339- # finally, reset the environment
340- start = time .time ()
322+ # reset the environment
341323 obs , info = self .env .reset (seed = self .seed )
342- end = time .time ()
343- logger .info (f"env reset done in { end - start } " )
344324
345325 self .last_obs = copy .deepcopy (obs )
346326 self .last_info = copy .deepcopy (info )
@@ -370,14 +350,12 @@ def step(self, action: str) -> dict:
370350 "message" : "Environment not created. Please create an environment first." ,
371351 }
372352 )
373- start = time . time ()
353+ # step the environment
374354 obs , reward , terminated , truncated , info = self .env .step (action )
375- end = time .time ()
376- logger .info (f"env step done in { end - start } " )
377- start = time .time ()
355+
378356 self .last_obs = copy .deepcopy (obs )
379357 self .last_info = copy .deepcopy (info )
380- out = make_json_safe (
358+ return make_json_safe (
381359 {
382360 "status" : "success" ,
383361 "message" : "Environment stepped successfully." ,
@@ -388,9 +366,6 @@ def step(self, action: str) -> dict:
388366 "info" : info ,
389367 }
390368 )
391- end = time .time ()
392- logger .info (f"obs copied in { end - start } " )
393- return out
394369
395370 def get_obs (self ) -> dict :
396371 """Get the last observation
0 commit comments