Skip to content

Commit 7a4f8fc

Browse files
TLSDCgasse
authored and committed
adding a pricy test
1 parent 379508b commit 7a4f8fc

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

tests/agents/test_vwa_agent.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import tempfile
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
import pytest
6+
7+
from agentlab.agents.visualwebarena.agent import VWAAgent, VWAAgentArgs
8+
from agentlab.analyze import inspect_results
9+
from agentlab.experiments.study import Study
10+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
11+
12+
# Small PNG encoded as a base64 ``data:`` URL; MockGoalImageAgent appends it to the goal as an "image_url" entry.
mock_image = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
13+
14+
15+
class MockGoalImageAgent(VWAAgent):
    """``VWAAgent`` subclass that appends a mock image to the preprocessed goal."""

    def obs_preprocessor(self, obs: dict) -> Any:
        """Run the parent preprocessor, then add an image entry to the goal.

        The parent's result must carry a ``"goal_object"`` that is a 1-tuple
        holding a ``{"type": "text", "text": ...}`` dict; this is asserted
        before the goal is replaced by a 2-tuple of that text entry followed by
        an ``"image_url"`` entry pointing at ``mock_image``.

        Args:
            obs: Raw observation, forwarded unchanged to the parent class.

        Returns:
            The parent's preprocessed observation with ``"goal_object"``
            replaced as described above.

        Raises:
            AssertionError: If the parent's goal does not have the expected
                single-text-entry shape.
        """
        processed = super().obs_preprocessor(obs)

        # Sanity-check the container first, then the single text entry in it.
        goal = processed["goal_object"]
        assert isinstance(goal, tuple)
        assert len(goal) == 1

        text_entry = goal[0]
        assert isinstance(text_entry, dict)
        assert "type" in text_entry
        assert text_entry["type"] == "text"
        assert "text" in text_entry

        image_entry = {"type": "image_url", "image_url": {"url": mock_image}}
        processed["goal_object"] = (text_entry, image_entry)
        return processed
29+
30+
31+
@dataclass
class MockGoalImageAgentArgs(VWAAgentArgs):
    """Agent arguments that build a ``MockGoalImageAgent`` instead of the base agent."""

    agent_name: str = "debug_vwa"
    temperature: float = 0.1
    # Annotated on purpose: ``@dataclass`` only turns *annotated* class
    # attributes into fields. Without the annotation this would be a plain
    # class attribute, and the ``chat_model_args=...`` keyword would only be
    # accepted if a base class happened to declare the field.
    chat_model_args: Any = None

    def make_agent(self) -> MockGoalImageAgent:
        """Instantiate the mock agent from these arguments.

        Returns:
            A ``MockGoalImageAgent`` constructed with this instance's
            ``chat_model_args`` and ``n_retry=3``.
        """
        return MockGoalImageAgent(
            chat_model_args=self.chat_model_args,
            n_retry=3,
        )
42+
43+
44+
@pytest.mark.pricy
def test_mock_goal_image_agent():
    """Run the mock goal-image agent on ``miniwob_tiny_test`` and check the results.

    Builds a one-agent ``Study`` using the ``openai/gpt-4o-mini-2024-07-18``
    chat model configuration, runs it sequentially, prints any per-experiment
    errors, and asserts that there is one result row per experiment, a single
    summary row, and an average reward of exactly 1.0.
    """
    # NOTE(review): ``tmp_dir`` is created but never handed to the study, so
    # results are presumably written to the study's default location -- confirm
    # whether the study was meant to be rooted in this temporary directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        agent_args = MockGoalImageAgentArgs(
            chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"]
        )
        study = Study([agent_args], benchmark="miniwob_tiny_test")
        study.run(n_jobs=1, parallel_backend="sequential")

        results_df = inspect_results.load_result_df(study.dir, progress_fn=None)

        # Surface failures in the captured output before the assertions fire.
        for _, record in results_df.iterrows():
            if record.err_msg:
                print(record.err_msg)
                print(record.stack_trace)

        assert len(results_df) == len(study.exp_args_list)

        summary = inspect_results.summarize_study(results_df)
        print(summary)
        assert len(summary) == 1

        reward = summary.avg_reward.iloc[0]
        assert reward == 1.0
71+
72+
73+
if __name__ == "__main__":
    # Allow running this test directly as a script, without going through pytest.
    test_mock_goal_image_agent()

0 commit comments

Comments
 (0)