1+ """
2+ Regression test for extract_aigrant_companies functionality.
3+
4+ This test verifies that data extraction works correctly by extracting
5+ companies that received AI grants along with their batch numbers,
6+ based on the TypeScript extract_aigrant_companies evaluation.
7+ """
8+
9+ import os
10+ import pytest
11+ import pytest_asyncio
12+ from pydantic import BaseModel , Field
13+ from typing import List
14+
15+ from stagehand import Stagehand , StagehandConfig
16+ from stagehand .schemas import ExtractOptions
17+
18+
# Extraction schema for a single AI-grant recipient.
#
# NOTE: documented with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into the generated JSON schema, which
# would change the schema handed to the extractor.
class Company(BaseModel):
    # Name of the company; required (no default).
    company: str = Field(
        ...,
        description="The name of the company",
    )
    # Grant batch identifier, kept as a string; required (no default).
    batch: str = Field(
        ...,
        description="The batch number of the grant",
    )
22+
23+
# Top-level extraction schema: wraps the list of extracted Company entries.
#
# NOTE: documented with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into the generated JSON schema, which
# would change the schema handed to the extractor.
class Companies(BaseModel):
    # Every extracted company, in the order the extractor returned them;
    # required (no default).
    companies: List[Company] = Field(
        ...,
        description="List of companies that received AI grants",
    )
26+
27+
class TestExtractAigrantCompanies:
    """Regression test for extract_aigrant_companies functionality.

    The same scenario is run twice, once against a LOCAL browser and once
    against BROWSERBASE: open the AI grant eval site, extract every company
    with its batch number, then check the total count and the first and last
    entries. The scenario itself lives in private helpers so both tests stay
    in lockstep.
    """

    # Scenario inputs and expectations shared by both environments.
    _SITE_URL = "https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/"
    _INSTRUCTION = (
        "Extract all companies that received the AI grant and group them with their "
        "batch numbers as an array of objects. Each object should contain the company "
        "name and its corresponding batch number."
    )
    _EXPECTED_COUNT = 91
    _EXPECTED_FIRST = {"company": "Goodfire", "batch": "4"}
    _EXPECTED_LAST = {"company": "Forefront", "batch": "1"}

    @pytest.fixture(scope="class")
    def local_config(self):
        """Configuration for LOCAL mode testing"""
        return StagehandConfig(
            env="LOCAL",
            model_name="gpt-4o-mini",
            headless=True,
            verbose=1,
            dom_settle_timeout_ms=2000,
            # MODEL_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
            model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")},
        )

    @pytest.fixture(scope="class")
    def browserbase_config(self):
        """Configuration for BROWSERBASE mode testing"""
        return StagehandConfig(
            env="BROWSERBASE",
            api_key=os.getenv("BROWSERBASE_API_KEY"),
            project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
            model_name="gpt-4o",
            headless=False,
            verbose=2,
            # MODEL_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
            model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")},
        )

    @pytest_asyncio.fixture
    async def local_stagehand(self, local_config):
        """Create a Stagehand instance for LOCAL testing"""
        stagehand = Stagehand(config=local_config)
        await stagehand.init()
        yield stagehand
        # Teardown: runs once the test using this fixture has finished.
        await stagehand.close()

    @pytest_asyncio.fixture
    async def browserbase_stagehand(self, browserbase_config):
        """Create a Stagehand instance for BROWSERBASE testing"""
        # Skip before creating the Stagehand instance when either credential
        # is missing.
        if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")):
            pytest.skip("Browserbase credentials not available")

        stagehand = Stagehand(config=browserbase_config)
        await stagehand.init()
        yield stagehand
        # Teardown: runs once the test using this fixture has finished.
        await stagehand.close()

    async def _extract_companies(self, stagehand):
        """Navigate to the eval site and return the extracted companies.

        Args:
            stagehand: An initialised Stagehand instance (LOCAL or BROWSERBASE).

        Returns:
            The list of ``Company`` entries validated against ``Companies``.
        """
        await stagehand.page.goto(self._SITE_URL)

        extract_options = ExtractOptions(
            instruction=self._INSTRUCTION,
            schema_definition=Companies,
        )
        result = await stagehand.page.extract(extract_options)

        # TODO - how to unify the extract result handling between LOCAL and BROWSERBASE?
        # A result exposing a non-empty ``data`` attribute is treated as a
        # wrapper around the payload (BROWSERBASE format per the original
        # test); otherwise the result itself is taken to be the model
        # instance (LOCAL format) and is dumped before re-validation.
        if hasattr(result, "data") and result.data:
            payload = result.data
        else:
            payload = result.model_dump()

        return Companies.model_validate(payload).companies

    @staticmethod
    def _assert_company(actual, expected, position):
        """Assert that one extracted company matches the expected name and batch.

        Args:
            actual: The extracted ``Company`` entry.
            expected: Dict with the expected ``company`` and ``batch`` values.
            position: Label used in failure messages ("first" or "last").
        """
        assert actual.company == expected["company"], (
            f"Expected {position} company to be '{expected['company']}', "
            f"but got '{actual.company}'"
        )
        assert actual.batch == expected["batch"], (
            f"Expected {position} company batch to be '{expected['batch']}', "
            f"but got '{actual.batch}'"
        )

    def _assert_expected_companies(self, companies):
        """Check the total count and the first and last extracted companies."""
        # The count check runs first, so the list is known to be non-empty
        # before it is indexed; no separate emptiness assertion is needed.
        assert len(companies) == self._EXPECTED_COUNT, (
            f"Expected {self._EXPECTED_COUNT} companies, but got {len(companies)}"
        )
        self._assert_company(companies[0], self._EXPECTED_FIRST, "first")
        self._assert_company(companies[-1], self._EXPECTED_LAST, "last")

    @pytest.mark.asyncio
    @pytest.mark.regression
    @pytest.mark.local
    async def test_extract_aigrant_companies_local(self, local_stagehand):
        """
        Regression test: extract_aigrant_companies

        Mirrors the TypeScript extract_aigrant_companies evaluation:
        - Navigate to AI grant companies test site
        - Extract all companies that received AI grants with their batch numbers
        - Verify total count is 91
        - Verify first company is "Goodfire" in batch "4"
        - Verify last company is "Forefront" in batch "1"
        """
        companies = await self._extract_companies(local_stagehand)
        self._assert_expected_companies(companies)

    @pytest.mark.asyncio
    @pytest.mark.regression
    @pytest.mark.api
    @pytest.mark.skipif(
        not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")),
        reason="Browserbase credentials not available",
    )
    async def test_extract_aigrant_companies_browserbase(self, browserbase_stagehand):
        """
        Regression test: extract_aigrant_companies (Browserbase)

        Same test as local but running in Browserbase environment.
        """
        companies = await self._extract_companies(browserbase_stagehand)
        self._assert_expected_companies(companies)
0 commit comments