
Commit 477132f
[feat] add webshop environment
1 parent 85872b6

60 files changed (+7450 / -0 lines)
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .projection import webshop_projection
from .envs import build_webshop_envs
Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
import ray
import gym
import numpy as np

# -----------------------------------------------------------------------------
# Ray remote worker actor -----------------------------------------------------
# -----------------------------------------------------------------------------

@ray.remote(num_cpus=0.2)
class WebshopWorker:
    """Ray remote actor that replaces the worker function.

    Each actor hosts a *WebAgentTextEnv* instance.
    """

    def __init__(self, seed, env_kwargs):
        # Lazy import avoids CUDA initialisation issues
        import sys
        import os
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), 'webshop'))
        sys.path.append(project_root)
        from web_agent_site.envs import WebAgentTextEnv  # noqa: WPS433 (runtime import)

        env_kwargs['seed'] = seed
        self.env = gym.make('WebAgentTextEnv-v0', **env_kwargs)

    def step(self, action):
        """Execute a step in the environment."""
        obs, reward, done, info = self.env.step(action)
        info = dict(info or {})  # make a *copy* so we can mutate safely
        info['available_actions'] = self.env.get_available_actions()
        info['task_score'] = reward

        # Redefine reward. We only use a rule-based reward: win for 10, lose for 0.
        if done and reward == 1.0:
            info['won'] = True
            reward = 10.0
        else:
            info['won'] = False
            reward = 0

        return obs, reward, done, info

    def reset(self, idx):
        """Reset the environment with the given session index."""
        obs, info = self.env.reset(session=idx)
        info = dict(info or {})
        info['available_actions'] = self.env.get_available_actions()
        info['won'] = False
        return obs, info

    def render(self, mode_for_render):
        """Render the environment."""
        rendered = self.env.render(mode=mode_for_render)
        return rendered

    def get_available_actions(self):
        """Get available actions."""
        return self.env.get_available_actions()

    def get_goals(self):
        """Get environment goals."""
        return self.env.server.goals

    def close(self):
        """Close the environment."""
        self.env.close()


# -----------------------------------------------------------------------------
# Vectorised Ray environment --------------------------------------------------
# -----------------------------------------------------------------------------

class WebshopMultiProcessEnv(gym.Env):
    """A vectorised, Ray-based wrapper around *WebAgentTextEnv*.

    ``info`` dictionaries returned by :py:meth:`step` **and** :py:meth:`reset`
    automatically contain the key ``'available_actions'`` so downstream RL code
    can obtain the *legal* action set without extra IPC overhead.
    """

    def __init__(
        self,
        seed: int = 0,
        env_num: int = 1,
        group_n: int = 1,
        is_train: bool = True,
        env_kwargs: dict = None,
    ) -> None:
        super().__init__()

        # Initialize Ray if not already initialized
        if not ray.is_initialized():
            ray.init()

        self.group_n = group_n
        self.env_num = env_num
        self.num_processes = env_num * group_n
        self.is_train = is_train
        if not is_train:
            assert group_n == 1

        self._rng = np.random.RandomState(seed)

        self._env_kwargs = env_kwargs if env_kwargs is not None else {'observation_mode': 'text', 'num_products': None}

        # -------------------------- Ray actors setup --------------------------
        self._workers = []

        for i in range(self.num_processes):
            worker = WebshopWorker.remote(seed + (i // self.group_n), self._env_kwargs)
            self._workers.append(worker)

        # Get goals from the first worker
        goals_future = self._workers[0].get_goals.remote()
        goals = ray.get(goals_future)

        # ------- original ----------#
        # if args.num is None:
        #     if split == 'test':
        #         self.goal_idxs = range(500)
        #     elif split == 'eval':
        #         self.goal_idxs = range(500, 1500)
        #     elif split == 'train':
        #         self.goal_idxs = range(1500, len(self.env.server.goals))
        # else:
        #     self.goal_idxs = range(len(self.env.server.goals))

        if not self.is_train:
            self.goal_idxs = range(500)
        else:
            self.goal_idxs = range(500, len(goals))

        print(self.goal_idxs)

    # ------------------------------------------------------------------
    # Base API ----------------------------------------------------------
    # ------------------------------------------------------------------

    def step(self, actions: list[str]):
        if len(actions) != self.num_processes:
            raise ValueError(
                f'Expected {self.num_processes} actions, got {len(actions)}',
            )

        # Send step commands to all workers
        futures = []
        for worker, action in zip(self._workers, actions):
            future = worker.step.remote(action)
            futures.append(future)

        # Collect results
        results = ray.get(futures)
        obs_list, reward_list, done_list, info_list = [], [], [], []
        for obs, reward, done, info in results:
            obs_list.append(obs)
            reward_list.append(reward)
            done_list.append(done)
            info_list.append(info)

        return obs_list, reward_list, done_list, info_list

    def reset(self):
        idx = self._rng.choice(self.goal_idxs, size=self.env_num, replace=False)
        idx = np.repeat(idx, self.group_n).tolist()

        # Send reset commands to all workers
        futures = []
        for worker, i in zip(self._workers, idx):
            future = worker.reset.remote(i)
            futures.append(future)

        # Collect results
        results = ray.get(futures)
        obs_list, info_list = [], []
        for obs, info in results:
            obs_list.append(obs)
            info_list.append(info)

        return obs_list, info_list

    # ------------------------------------------------------------------
    # Convenience helpers ----------------------------------------------
    # ------------------------------------------------------------------

    def render(self, mode: str = 'text', env_idx: int = None):
        if env_idx is not None:
            future = self._workers[env_idx].render.remote(mode)
            return ray.get(future)

        futures = []
        for worker in self._workers:
            future = worker.render.remote(mode)
            futures.append(future)

        return ray.get(futures)

    # ------------------------------------------------------------------
    # Clean-up ----------------------------------------------------------
    # ------------------------------------------------------------------

    def close(self):
        if getattr(self, '_closed', False):
            return

        # Close all workers and kill Ray actors
        close_futures = []
        for worker in self._workers:
            future = worker.close.remote()
            close_futures.append(future)

        # Wait for all workers to close
        ray.get(close_futures)

        # Kill all Ray actors
        for worker in self._workers:
            ray.kill(worker)

        self._closed = True

    def __del__(self):  # noqa: D401
        self.close()


# -----------------------------------------------------------------------------
# Factory helper --------------------------------------------------------------
# -----------------------------------------------------------------------------

def build_webshop_envs(
    seed: int = 0,
    env_num: int = 1,
    group_n: int = 1,
    is_train: bool = True,
    env_kwargs: dict = None,
):
    """Mirror *build_sokoban_envs* so higher-level code can swap seamlessly."""
    return WebshopMultiProcessEnv(
        seed=seed,
        env_num=env_num,
        group_n=group_n,
        is_train=is_train,
        env_kwargs=env_kwargs,
    )
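
For orientation, here is a minimal usage sketch of the vectorised environment above. It is not part of the commit: it assumes the WebShop package and its product data are already set up so that 'WebAgentTextEnv-v0' can be created, that this module is importable as `envs` (adjust the import to the actual package path), and that the agent emits WebShop-style action strings such as search[...] and click[...].

import ray

from envs import build_webshop_envs  # hypothetical import path; adjust to the repo layout

# Two distinct goals, two rollouts per goal; each goal index is repeated
# group_n times so grouped rollouts share the same task.
vec_env = build_webshop_envs(seed=0, env_num=2, group_n=2, is_train=True)

obs_list, info_list = vec_env.reset()      # one observation per worker
print(info_list[0]['available_actions'])   # legal-action info piggy-backed on reset/step

# WebShop actions are plain strings; send exactly one per worker.
actions = ['search[red running shoes]'] * vec_env.num_processes
obs_list, rewards, dones, infos = vec_env.step(actions)

vec_env.close()
ray.shutdown()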
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
from typing import List
import re


def webshop_projection(actions: List[str]):
    """
    A function to process the actions.
    actions: the list of action strings to be processed (modified in place).
    Expected format:
        <think>some reasoning...</think><action>search[...] or click[...]</action>
    """

    valids = [0] * len(actions)

    for i in range(len(actions)):
        original_str = actions[i]  # keep the original string
        actions[i] = actions[i].lower()

        # Attempt to extract the substring within <action>...</action>
        start_tag = "<action>"
        end_tag = "</action>"
        start_idx = actions[i].find(start_tag)
        end_idx = actions[i].find(end_tag)
        try:
            if start_idx == -1 or end_idx == -1:
                # No valid <action>...</action> block: keep only the tail of the
                # string, which will not match any legal webshop action.
                actions[i] = actions[i][-20:]
                continue

            # Extract just the content between the tags
            extracted_action = actions[i][start_idx + len(start_tag):end_idx].strip().lower()

            actions[i] = extracted_action
            valids[i] = 1

        except Exception:
            # Fall back to the truncated raw string if extraction fails
            actions[i] = actions[i][-20:]

        # check <think>...</think>
        think_start_idx = original_str.find("<think>")
        think_end_idx = original_str.find("</think>")
        if think_start_idx == -1 or think_end_idx == -1:
            valids[i] = 0

        # check if the output contains any Chinese characters
        if re.search(r'[\u4e00-\u9fff]', original_str):
            valids[i] = 0

    return actions, valids
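
As a quick illustration (not part of the commit): the projection mutates the input list in place, so pass a copy if the raw strings are still needed. A well-formed <think>...</think><action>...</action> output is extracted and marked valid; anything else is truncated and marked invalid.

raw = [
    "<think>I need running shoes.</think><action>search[red running shoes]</action>",
    "click the first result",  # no <action> block
]
actions, valids = webshop_projection(list(raw))  # copy so `raw` is untouched
# actions[0] == 'search[red running shoes]' and valids == [1, 0];
# actions[1] is just the last 20 characters of the lowered raw string.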
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
*.ipynb*
*.pyc
*.swp

.DS_Store
.idea/
.pytest_cache/
.vscode/

__pycache__/
data/
search_engine/indexes*
search_engine/resources*
transfer/flagged
user_session_logs/


*_err_*.log
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Princeton Natural Language Processing

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
