from __future__ import annotations

import ray
import gym
import numpy as np

# -----------------------------------------------------------------------------
# Ray remote worker actor -----------------------------------------------------
# -----------------------------------------------------------------------------
@ray.remote(num_cpus=0.2)
class WebshopWorker:
    """Ray remote actor that replaces the worker function.

    Each actor hosts a *WebAgentTextEnv* instance.
    """

    def __init__(self, seed, env_kwargs):
        # Lazy imports keep heavy dependencies out of the driver process and
        # avoid CUDA initialisation issues.
        import sys
        import os
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), 'webshop'))
        sys.path.append(project_root)
        # Importing web_agent_site registers 'WebAgentTextEnv-v0' with gym.
        from web_agent_site.envs import WebAgentTextEnv  # noqa: WPS433, F401 (runtime import)

        env_kwargs = dict(env_kwargs)  # defensive copy before mutating
        env_kwargs['seed'] = seed
        self.env = gym.make('WebAgentTextEnv-v0', **env_kwargs)

    def step(self, action):
        """Execute one step in the environment."""
        obs, reward, done, info = self.env.step(action)
        info = dict(info or {})  # make a *copy* so we can mutate safely
        info['available_actions'] = self.env.get_available_actions()
        info['task_score'] = reward

        # Redefine the reward as rule-based only: 10.0 for a win, 0.0 otherwise.
        if done and reward == 1.0:
            info['won'] = True
            reward = 10.0
        else:
            info['won'] = False
            reward = 0.0

        return obs, reward, done, info

    def reset(self, idx):
        """Reset the environment with the given session index."""
        obs, info = self.env.reset(session=idx)
        info = dict(info or {})
        info['available_actions'] = self.env.get_available_actions()
        info['won'] = False
        return obs, info

    def render(self, mode):
        """Render the environment."""
        return self.env.render(mode=mode)

    def get_available_actions(self):
        """Get the currently available actions."""
        return self.env.get_available_actions()

    def get_goals(self):
        """Get the environment's goal list."""
        return self.env.server.goals

    def close(self):
        """Close the environment."""
        self.env.close()
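
# A minimal smoke test for a single worker actor -- hypothetical usage, assuming
# Ray is initialised, the webshop package is importable, and actions follow the
# usual WebAgentTextEnv string format ('search[...]' / 'click[...]'):
#
#     worker = WebshopWorker.remote(seed=0, env_kwargs={'observation_mode': 'text'})
#     obs, info = ray.get(worker.reset.remote(0))
#     obs, reward, done, info = ray.get(worker.step.remote('search[red shoes]'))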


# -----------------------------------------------------------------------------
# Vectorised Ray environment --------------------------------------------------
# -----------------------------------------------------------------------------

class WebshopMultiProcessEnv(gym.Env):
    """A vectorised, Ray-based wrapper around *WebAgentTextEnv*.

    ``info`` dictionaries returned by :py:meth:`step` **and** :py:meth:`reset`
    automatically contain the key ``'available_actions'`` so downstream RL code
    can obtain the *legal* action set without extra IPC overhead.
    """

    def __init__(
        self,
        seed: int = 0,
        env_num: int = 1,
        group_n: int = 1,
        is_train: bool = True,
        env_kwargs: dict | None = None,
    ) -> None:
        super().__init__()

        # Initialise Ray if it is not already running.
        if not ray.is_initialized():
            ray.init()

        self.group_n = group_n
        self.env_num = env_num
        self.num_processes = env_num * group_n
        self.is_train = is_train
        if not is_train:
            assert group_n == 1, 'group_n must be 1 at evaluation time'

        self._rng = np.random.RandomState(seed)

        self._env_kwargs = env_kwargs if env_kwargs is not None else {
            'observation_mode': 'text',
            'num_products': None,
        }

        # -------------------------- Ray actors setup --------------------------
        self._workers = []

        for i in range(self.num_processes):
            # Workers in the same group share a seed, so every member of a
            # group hosts an identically seeded environment.
            worker = WebshopWorker.remote(seed + (i // self.group_n), self._env_kwargs)
            self._workers.append(worker)

        # Fetch the goal list from the first worker.
        goals = ray.get(self._workers[0].get_goals.remote())

        # ------- original split logic, kept for reference -------
        # if args.num is None:
        #     if split == 'test':
        #         self.goal_idxs = range(500)
        #     elif split == 'eval':
        #         self.goal_idxs = range(500, 1500)
        #     elif split == 'train':
        #         self.goal_idxs = range(1500, len(self.env.server.goals))
        # else:
        #     self.goal_idxs = range(len(self.env.server.goals))

        # Simplified split: the first 500 goals for evaluation, the rest for training.
        if not self.is_train:
            self.goal_idxs = range(500)
        else:
            self.goal_idxs = range(500, len(goals))

        print(f'goal idxs: {self.goal_idxs}')

    # ------------------------------------------------------------------
    # Base API ----------------------------------------------------------
    # ------------------------------------------------------------------

    def step(self, actions: list[str]):
        if len(actions) != self.num_processes:
            raise ValueError(
                f'Expected {self.num_processes} actions, got {len(actions)}',
            )

        # Dispatch step commands to all workers in parallel.
        futures = [
            worker.step.remote(action)
            for worker, action in zip(self._workers, actions)
        ]

        # Collect results.
        results = ray.get(futures)
        obs_list, reward_list, done_list, info_list = [], [], [], []
        for obs, reward, done, info in results:
            obs_list.append(obs)
            reward_list.append(reward)
            done_list.append(done)
            info_list.append(info)

        return obs_list, reward_list, done_list, info_list

    def reset(self):
        # Sample env_num distinct goals, then repeat each group_n times so every
        # member of a group resets to the same goal.
        idx = self._rng.choice(self.goal_idxs, size=self.env_num, replace=False)
        idx = np.repeat(idx, self.group_n).tolist()

        # Dispatch reset commands to all workers in parallel.
        futures = [
            worker.reset.remote(i)
            for worker, i in zip(self._workers, idx)
        ]

        # Collect results.
        results = ray.get(futures)
        obs_list, info_list = [], []
        for obs, info in results:
            obs_list.append(obs)
            info_list.append(info)

        return obs_list, info_list

    # ------------------------------------------------------------------
    # Convenience helpers ----------------------------------------------
    # ------------------------------------------------------------------

    def render(self, mode: str = 'text', env_idx: int | None = None):
        # Render a single environment when env_idx is given, otherwise all of them.
        if env_idx is not None:
            return ray.get(self._workers[env_idx].render.remote(mode))

        futures = [worker.render.remote(mode) for worker in self._workers]
        return ray.get(futures)

    # ------------------------------------------------------------------
    # Clean-up ----------------------------------------------------------
    # ------------------------------------------------------------------

    def close(self):
        if getattr(self, '_closed', False):
            return

        # Close every environment, then kill the Ray actors.
        close_futures = [worker.close.remote() for worker in self._workers]
        ray.get(close_futures)

        for worker in self._workers:
            ray.kill(worker)

        self._closed = True

    def __del__(self):  # noqa: D401
        # Best-effort clean-up; Ray may already be shut down at interpreter exit.
        try:
            self.close()
        except Exception:
            pass


# -----------------------------------------------------------------------------
# Factory helper --------------------------------------------------------------
# -----------------------------------------------------------------------------

def build_webshop_envs(
    seed: int = 0,
    env_num: int = 1,
    group_n: int = 1,
    is_train: bool = True,
    env_kwargs: dict | None = None,
):
    """Mirror *build_sokoban_envs* so higher-level code can swap seamlessly."""
    return WebshopMultiProcessEnv(
        seed=seed,
        env_num=env_num,
        group_n=group_n,
        is_train=is_train,
        env_kwargs=env_kwargs,
    )
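

# -----------------------------------------------------------------------------
# Example driver --------------------------------------------------------------
# -----------------------------------------------------------------------------

# A minimal sketch of the intended driver loop, not part of the original API.
# It assumes a local WebShop installation with its data files in place and
# uses the usual WebAgentTextEnv action format ('search[...]' / 'click[...]').
if __name__ == '__main__':
    envs = build_webshop_envs(seed=0, env_num=2, group_n=1, is_train=False)
    obs_list, info_list = envs.reset()

    # A real agent would pick actions from info['available_actions']; a
    # hard-coded search query is enough for a smoke test.
    actions = ['search[red shoes]'] * envs.num_processes
    obs_list, rewards, dones, infos = envs.step(actions)
    print(rewards, dones)

    envs.close()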