Skip to content

Commit b7b2b82

Browse files
rename data files, bug fixes
1 parent 5268a3b commit b7b2b82

File tree

11 files changed

+129
-32
lines changed

11 files changed

+129
-32
lines changed

src/browsergym/workarena/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@
6262
task,
6363
)
6464

65+
# register dynamic guidance tasks
66+
for task in ALL_WORKARENA_DYNAMIC_GUIDANCE_TASKS:
67+
register_task(
68+
task.get_task_id(),
69+
task,
70+
)
71+
6572
workarena_tasks_all = [task_class.get_task_id() for task_class in ALL_WORKARENA_TASKS]
6673
workarena_tasks_atomic = [task_class.get_task_id() for task_class in ATOMIC_TASKS]
6774

@@ -139,7 +146,7 @@ def get_all_tasks_agents(filter="l2", meta_seed=42, n_seed_l1=10, is_agent_curri
139146

140147
return all_task_tuples
141148
elif level == "dg":
142-
for task in ATOMIC_TASKS:
149+
for task in ALL_WORKARENA_DYNAMIC_GUIDANCE_TASKS:
143150
for seed in rng.randint(0, 1000, n_seed_l1):
144151
all_task_tuples.append((task, int(seed)))
145152

src/browsergym/workarena/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,29 @@
233233
# Report date filter patch flag
234234
REPORT_PATCH_FLAG = "WORKARENA_DATE_FILTER_PATCH"
235235
REPORT_FILTER_PROPERTY = "workarena.report.filter.config"
236+
237+
238+
# Case tasks
239+
GET_CASE_STATUS_CONFIG_PATH = str(
240+
resources.files(data_files).joinpath("task_configs/get_case_status.json")
241+
)
242+
GET_CASE_RESOLUTION_NOTES_CONFIG_PATH = str(
243+
resources.files(data_files).joinpath("task_configs/get_case_resnotes.json")
244+
)
245+
CLOSE_CASE_CONFIG_PATH = str(
246+
resources.files(data_files).joinpath("task_configs/close_case.json")
247+
)
248+
FIND_ASSET_UNDER_ACCOUNT_CREATE_CASE_CONFIG_PATH = str(
249+
resources.files(data_files).joinpath("task_configs/find_asset_under_account_create_case.json")
250+
)
251+
252+
# Role tasks
253+
ASSIGN_ROLE_TO_USER_ADMIN_CONFIG_PATH = str(
254+
resources.files(data_files).joinpath("task_configs/assign_role_to_user_admin.json")
255+
)
256+
ASSIGN_ROLES_TO_USER_EXPLICIT_CONFIG_PATH = str(
257+
resources.files(data_files).joinpath("task_configs/assign_roles_to_user_explicit.json")
258+
)
259+
ASSIGN_ROLES_TO_USER_IMPLICIT_CONFIG_PATH = str(
260+
resources.files(data_files).joinpath("task_configs/assign_roles_to_user_implicit.json")
261+
)

src/browsergym/workarena/data_files/task_configs/assign-role-to-user.json renamed to src/browsergym/workarena/data_files/task_configs/assign_role_to_user_admin.json

File renamed without changes.

src/browsergym/workarena/data_files/task_configs/assign-roles-to-user-explicit.json renamed to src/browsergym/workarena/data_files/task_configs/assign_roles_to_user_explicit.json

File renamed without changes.

src/browsergym/workarena/data_files/task_configs/assign-roles-to-user-implicit.json renamed to src/browsergym/workarena/data_files/task_configs/assign_roles_to_user_implicit.json

File renamed without changes.

src/browsergym/workarena/data_files/task_configs/close-case.json renamed to src/browsergym/workarena/data_files/task_configs/close_case.json

File renamed without changes.

src/browsergym/workarena/data_files/task_configs/find-asset-under-account-create-case.json renamed to src/browsergym/workarena/data_files/task_configs/find_asset_under_account_create_case.json

File renamed without changes.

src/browsergym/workarena/data_files/task_configs/get-case-resnotes.json renamed to src/browsergym/workarena/data_files/task_configs/get_case_resnotes.json

File renamed without changes.

src/browsergym/workarena/data_files/task_configs/get-case-status.json renamed to src/browsergym/workarena/data_files/task_configs/get_case_status.json

File renamed without changes.

src/browsergym/workarena/tasks/case.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, Dict, List, Tuple
23

34
import playwright.sync_api
@@ -9,28 +10,37 @@
910
table_api_call,
1011
table_column_info,
1112
)
13+
from ..config import (
14+
CLOSE_CASE_CONFIG_PATH,
15+
FIND_ASSET_UNDER_ACCOUNT_CREATE_CASE_CONFIG_PATH,
16+
GET_CASE_RESOLUTION_NOTES_CONFIG_PATH,
17+
GET_CASE_STATUS_CONFIG_PATH,
18+
)
1219
from .base import AbstractServiceNowTask
1320

21+
1422
class ServiceNowCaseTask(AbstractServiceNowTask):
15-
1623

17-
def __init__(self, seed: int, config: Dict[str, Any]) -> None:
18-
super().__init__(seed)
24+
def __init__(self, seed: int, fixed_config: Dict[str, Any] = None, start_rel_url: str = "/now/nav/ui/home") -> None:
25+
super().__init__(seed, start_rel_url=start_rel_url)
1926
self.task_is_setup = False
20-
self.config = config
27+
self.config = fixed_config if fixed_config else self.random.choice(self.all_configs())
2128
self.timeout = 60000
2229

2330
def setup_goal(self, page: playwright.sync_api.Page) -> Tuple[str, dict]:
2431
goal = self.config["goal"]
2532
info = self.config
2633
return goal, info
2734

35+
def all_configs(self):
36+
raise NotImplementedError
37+
2838

2939
class GetCaseStatusTask(ServiceNowCaseTask):
3040

3141
def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) -> Tuple[float, bool, str, dict]:
3242
state = self.config["state"]
33-
if state.lower() in chat_messages[-1]["content"].lower():
43+
if state.lower() in chat_messages[-1]["message"].lower():
3444
return (
3545
1,
3646
True,
@@ -44,11 +54,14 @@ def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) ->
4454
{"message": "The state does not match."},
4555
)
4656

57+
def all_configs(self):
58+
return json.load(open(GET_CASE_STATUS_CONFIG_PATH))
59+
4760

4861
class CloseCaseTask(ServiceNowCaseTask):
49-
50-
def __init__(self, seed: int, config: Dict[str, Any]) -> None:
51-
super().__init__(seed, config)
62+
63+
def __init__(self, *args, **kwargs) -> None:
64+
super().__init__(*args, **kwargs)
5265

5366
self.initial_state = table_api_call(
5467
instance=self.instance,
@@ -61,12 +74,12 @@ def __init__(self, seed: int, config: Dict[str, Any]) -> None:
6174
)["result"][0]["state"]
6275

6376
def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) -> Tuple[float, bool, str, dict]:
64-
77+
6578
# gather info from config
6679
case_number = self.config["case_number"]
6780
resolution_code = self.config["resolution_code"]
6881
close_notes = self.config["close_notes"]
69-
82+
7083
# Query sn_customerservice_case in ServiceNow
7184
record = table_api_call(
7285
instance=self.instance,
@@ -103,7 +116,7 @@ def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) ->
103116
)
104117

105118
def teardown(self) -> None:
106-
119+
107120
# revert the state to initial_state
108121
table_api_call(
109122
instance=self.instance,
@@ -119,14 +132,17 @@ def teardown(self) -> None:
119132
},
120133
)
121134

135+
def all_configs(self):
136+
return json.load(open(CLOSE_CASE_CONFIG_PATH))
137+
122138

123139
class GetCaseResolutionNotesTask(ServiceNowCaseTask):
124-
140+
125141
def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) -> Tuple[float, bool, str, dict]:
126142
close_notes = self.config["close_notes"]
127143

128144
# check for close_notes
129-
if close_notes.lower() in chat_messages[-1]["content"].lower():
145+
if close_notes.lower() in chat_messages[-1]["message"].lower():
130146
return (
131147
1,
132148
True,
@@ -138,17 +154,21 @@ def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) ->
138154
False,
139155
"",
140156
{"message": "The close notes do not match."},
141-
)
157+
)
158+
159+
def all_configs(self):
160+
return json.load(open(GET_CASE_RESOLUTION_NOTES_CONFIG_PATH))
161+
142162

143163
class FindAssetUnderAccountCreateCaseTask(ServiceNowCaseTask):
144-
145-
def __init__(self, seed: int, config: Dict[str, Any]) -> None:
146-
super().__init__(seed, config)
164+
165+
def __init__(self, *args, **kwargs) -> None:
166+
super().__init__(*args, **kwargs)
147167
self.record_sys_id = None
148168

149169
def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) -> Tuple[float, bool, str, dict]:
150170

151-
customer_account = self.config["customer_account"]
171+
customer_account = self.config["customer_account"]
152172
assets = [elem.strip() for elem in self.config.get("assets", "").split(",")]
153173

154174
# find customer account sys id
@@ -197,7 +217,6 @@ def validate(self, page: playwright.sync_api.Page, chat_messages: List[str]) ->
197217
# check for assets
198218
# TODO: this is not the best way to do it
199219
for short_description in short_descriptions:
200-
print(short_description)
201220
if all(asset.lower() in short_description.lower() for asset in assets):
202221
return (
203222
1,
@@ -221,9 +240,13 @@ def teardown(self) -> None:
221240
table="sn_customerservice_case",
222241
)
223242

243+
def all_configs(self):
244+
return json.load(open(FIND_ASSET_UNDER_ACCOUNT_CREATE_CASE_CONFIG_PATH))
245+
246+
224247
__TASKS__ = [
225248
GetCaseStatusTask,
226249
CloseCaseTask,
227250
GetCaseResolutionNotesTask,
228251
FindAssetUnderAccountCreateCaseTask,
229-
]
252+
]

0 commit comments

Comments
 (0)