Skip to content

Commit 1cb9463

Browse files
tests added, have some TODOs
1 parent 6800683 commit 1cb9463

File tree

1 file changed

+230
-0
lines changed

1 file changed

+230
-0
lines changed
Lines changed: 230 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,230 @@
1+
"""
2+
Regression test for extract_aigrant_companies functionality.
3+
4+
This test verifies that data extraction works correctly by extracting
5+
companies that received AI grants along with their batch numbers,
6+
based on the TypeScript extract_aigrant_companies evaluation.
7+
"""
8+
9+
import os
10+
import pytest
11+
import pytest_asyncio
12+
from pydantic import BaseModel, Field
13+
from typing import List
14+
15+
from stagehand import Stagehand, StagehandConfig
16+
from stagehand.schemas import ExtractOptions
17+
18+
19+
# One extracted row: a company name together with its grant batch number.
# Documented with comments instead of a class docstring so that the JSON
# schema generated from this model stays exactly as it was.
class Company(BaseModel):
    company: str = Field(
        ...,
        description="The name of the company",
    )
    batch: str = Field(
        ...,
        description="The batch number of the grant",
    )
22+
23+
24+
# Top-level extraction schema: wraps the list of extracted Company rows.
# Documented with comments instead of a class docstring so that the JSON
# schema generated from this model stays exactly as it was.
class Companies(BaseModel):
    companies: List[Company] = Field(
        ...,
        description="List of companies that received AI grants",
    )
26+
27+
28+
class TestExtractAigrantCompanies:
    """Regression test for extract_aigrant_companies functionality.

    The LOCAL and BROWSERBASE tests run the same scenario, so the navigation,
    extraction, result normalisation and assertions live in private helpers
    shared by both tests.
    """

    # Page under test and the extraction prompt, shared by both environments.
    SITE_URL = "https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/"
    EXTRACT_INSTRUCTION = (
        "Extract all companies that received the AI grant and group them with their "
        "batch numbers as an array of objects. Each object should contain the company "
        "name and its corresponding batch number."
    )

    # Expected extraction results.
    EXPECTED_LENGTH = 91
    EXPECTED_FIRST_ITEM = {"company": "Goodfire", "batch": "4"}
    EXPECTED_LAST_ITEM = {"company": "Forefront", "batch": "1"}

    @pytest.fixture(scope="class")
    def local_config(self):
        """Configuration for LOCAL mode testing"""
        return StagehandConfig(
            env="LOCAL",
            model_name="gpt-4o-mini",
            headless=True,
            verbose=1,
            dom_settle_timeout_ms=2000,
            model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")},
        )

    @pytest.fixture(scope="class")
    def browserbase_config(self):
        """Configuration for BROWSERBASE mode testing"""
        return StagehandConfig(
            env="BROWSERBASE",
            api_key=os.getenv("BROWSERBASE_API_KEY"),
            project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
            model_name="gpt-4o",
            headless=False,
            verbose=2,
            model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")},
        )

    @pytest_asyncio.fixture
    async def local_stagehand(self, local_config):
        """Create a Stagehand instance for LOCAL testing"""
        stagehand = Stagehand(config=local_config)
        await stagehand.init()
        yield stagehand
        await stagehand.close()

    @pytest_asyncio.fixture
    async def browserbase_stagehand(self, browserbase_config):
        """Create a Stagehand instance for BROWSERBASE testing"""
        if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")):
            pytest.skip("Browserbase credentials not available")

        stagehand = Stagehand(config=browserbase_config)
        await stagehand.init()
        yield stagehand
        await stagehand.close()

    @staticmethod
    def _companies_from_result(result) -> List[Company]:
        """Normalise an extract() result into a list of Company objects.

        Args:
            result: Whatever ``page.extract`` returned. Either an object whose
                ``data`` attribute holds the payload, or the payload itself
                (a pydantic model instance or a plain mapping).

        Returns:
            The validated ``Companies.companies`` list.

        Raises:
            pydantic.ValidationError: If the payload does not match the
                ``Companies`` schema.
        """
        # TODO - how to unify the extract result handling between LOCAL and BROWSERBASE?
        # Prefer a non-empty ``data`` attribute; otherwise treat the result
        # itself as the payload.
        payload = getattr(result, "data", None)
        if not payload:
            payload = result
        # Dump model instances to plain data so validation does not depend on
        # the exact model class the payload was built with.
        if isinstance(payload, BaseModel):
            payload = payload.model_dump()
        return Companies.model_validate(payload).companies

    async def _extract_companies(self, stagehand) -> List[Company]:
        """Open the AI grant test site and extract every company with its batch.

        Args:
            stagehand: An initialised Stagehand instance supplied by a fixture.

        Returns:
            The extracted companies, in the order the extractor returned them.
        """
        await stagehand.page.goto(self.SITE_URL)

        extract_options = ExtractOptions(
            instruction=self.EXTRACT_INSTRUCTION,
            schema_definition=Companies,
        )
        result = await stagehand.page.extract(extract_options)
        return self._companies_from_result(result)

    @staticmethod
    def _assert_company(actual, expected, position):
        """Assert that one extracted company matches its expected name and batch.

        Args:
            actual: The extracted ``Company``.
            expected: Mapping with the expected ``company`` and ``batch`` values.
            position: Word used in the failure message ("first" or "last").
        """
        assert actual.company == expected["company"], (
            f"Expected {position} company to be '{expected['company']}', "
            f"but got '{actual.company}'"
        )
        assert actual.batch == expected["batch"], (
            f"Expected {position} company batch to be '{expected['batch']}', "
            f"but got '{actual.batch}'"
        )

    def _assert_expected_companies(self, companies):
        """Check the total count plus the first and last extracted entries."""
        # The count check also covers the empty case, so the indexing below
        # is safe once it passes.
        assert len(companies) == self.EXPECTED_LENGTH, (
            f"Expected {self.EXPECTED_LENGTH} companies, but got {len(companies)}"
        )
        self._assert_company(companies[0], self.EXPECTED_FIRST_ITEM, "first")
        self._assert_company(companies[-1], self.EXPECTED_LAST_ITEM, "last")

    @pytest.mark.asyncio
    @pytest.mark.regression
    @pytest.mark.local
    async def test_extract_aigrant_companies_local(self, local_stagehand):
        """
        Regression test: extract_aigrant_companies

        Mirrors the TypeScript extract_aigrant_companies evaluation:
        - Navigate to AI grant companies test site
        - Extract all companies that received AI grants with their batch numbers
        - Verify total count is 91
        - Verify first company is "Goodfire" in batch "4"
        - Verify last company is "Forefront" in batch "1"
        """
        companies = await self._extract_companies(local_stagehand)
        self._assert_expected_companies(companies)

    @pytest.mark.asyncio
    @pytest.mark.regression
    @pytest.mark.api
    @pytest.mark.skipif(
        not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")),
        reason="Browserbase credentials not available"
    )
    async def test_extract_aigrant_companies_browserbase(self, browserbase_stagehand):
        """
        Regression test: extract_aigrant_companies (Browserbase)

        Same test as local but running in Browserbase environment.
        """
        companies = await self._extract_companies(browserbase_stagehand)
        self._assert_expected_companies(companies)

0 commit comments

Comments
 (0)