1
+ """
2
+ Regression test for extract_aigrant_companies functionality.
3
+
4
+ This test verifies that data extraction works correctly by extracting
5
+ companies that received AI grants along with their batch numbers,
6
+ based on the TypeScript extract_aigrant_companies evaluation.
7
+ """
8
+
9
+ import os
10
+ import pytest
11
+ import pytest_asyncio
12
+ from pydantic import BaseModel , Field
13
+ from typing import List
14
+
15
+ from stagehand import Stagehand , StagehandConfig
16
+ from stagehand .schemas import ExtractOptions
17
+
18
+
19
+ class Company (BaseModel ):
20
+ company : str = Field (..., description = "The name of the company" )
21
+ batch : str = Field (..., description = "The batch number of the grant" )
22
+
23
+
24
+ class Companies (BaseModel ):
25
+ companies : List [Company ] = Field (..., description = "List of companies that received AI grants" )
26
+
27
+
28
+ class TestExtractAigrantCompanies :
29
+ """Regression test for extract_aigrant_companies functionality"""
30
+
31
+ @pytest .fixture (scope = "class" )
32
+ def local_config (self ):
33
+ """Configuration for LOCAL mode testing"""
34
+ return StagehandConfig (
35
+ env = "LOCAL" ,
36
+ model_name = "gpt-4o-mini" ,
37
+ headless = True ,
38
+ verbose = 1 ,
39
+ dom_settle_timeout_ms = 2000 ,
40
+ model_client_options = {"apiKey" : os .getenv ("MODEL_API_KEY" ) or os .getenv ("OPENAI_API_KEY" )},
41
+ )
42
+
43
+ @pytest .fixture (scope = "class" )
44
+ def browserbase_config (self ):
45
+ """Configuration for BROWSERBASE mode testing"""
46
+ return StagehandConfig (
47
+ env = "BROWSERBASE" ,
48
+ api_key = os .getenv ("BROWSERBASE_API_KEY" ),
49
+ project_id = os .getenv ("BROWSERBASE_PROJECT_ID" ),
50
+ model_name = "gpt-4o" ,
51
+ headless = False ,
52
+ verbose = 2 ,
53
+ model_client_options = {"apiKey" : os .getenv ("MODEL_API_KEY" ) or os .getenv ("OPENAI_API_KEY" )},
54
+ )
55
+
56
+ @pytest_asyncio .fixture
57
+ async def local_stagehand (self , local_config ):
58
+ """Create a Stagehand instance for LOCAL testing"""
59
+ stagehand = Stagehand (config = local_config )
60
+ await stagehand .init ()
61
+ yield stagehand
62
+ await stagehand .close ()
63
+
64
+ @pytest_asyncio .fixture
65
+ async def browserbase_stagehand (self , browserbase_config ):
66
+ """Create a Stagehand instance for BROWSERBASE testing"""
67
+ if not (os .getenv ("BROWSERBASE_API_KEY" ) and os .getenv ("BROWSERBASE_PROJECT_ID" )):
68
+ pytest .skip ("Browserbase credentials not available" )
69
+
70
+ stagehand = Stagehand (config = browserbase_config )
71
+ await stagehand .init ()
72
+ yield stagehand
73
+ await stagehand .close ()
74
+
75
+ @pytest .mark .asyncio
76
+ @pytest .mark .regression
77
+ @pytest .mark .local
78
+ async def test_extract_aigrant_companies_local (self , local_stagehand ):
79
+ """
80
+ Regression test: extract_aigrant_companies
81
+
82
+ Mirrors the TypeScript extract_aigrant_companies evaluation:
83
+ - Navigate to AI grant companies test site
84
+ - Extract all companies that received AI grants with their batch numbers
85
+ - Verify total count is 91
86
+ - Verify first company is "Goodfire" in batch "4"
87
+ - Verify last company is "Forefront" in batch "1"
88
+ """
89
+ stagehand = local_stagehand
90
+
91
+ await stagehand .page .goto ("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/" )
92
+
93
+ # Extract all companies with their batch numbers
94
+ extract_options = ExtractOptions (
95
+ instruction = (
96
+ "Extract all companies that received the AI grant and group them with their "
97
+ "batch numbers as an array of objects. Each object should contain the company "
98
+ "name and its corresponding batch number."
99
+ ),
100
+ schema_definition = Companies
101
+ )
102
+
103
+ result = await stagehand .page .extract (extract_options )
104
+
105
+ # TODO - how to unify the extract result handling between LOCAL and BROWSERBASE?
106
+
107
+ # Handle result based on the mode (LOCAL returns data directly, BROWSERBASE returns ExtractResult)
108
+ if hasattr (result , 'data' ) and result .data :
109
+ # BROWSERBASE mode format
110
+ companies_model = Companies .model_validate (result .data )
111
+ companies = companies_model .companies
112
+ else :
113
+ # LOCAL mode format - result is the Pydantic model instance
114
+ companies_model = Companies .model_validate (result .model_dump ())
115
+ companies = companies_model .companies
116
+
117
+ # Verify total count
118
+ expected_length = 91
119
+ assert len (companies ) == expected_length , (
120
+ f"Expected { expected_length } companies, but got { len (companies )} "
121
+ )
122
+
123
+ # Verify first company
124
+ expected_first_item = {
125
+ "company" : "Goodfire" ,
126
+ "batch" : "4"
127
+ }
128
+ assert len (companies ) > 0 , "No companies were extracted"
129
+ first_company = companies [0 ]
130
+ assert first_company .company == expected_first_item ["company" ], (
131
+ f"Expected first company to be '{ expected_first_item ['company' ]} ', "
132
+ f"but got '{ first_company .company } '"
133
+ )
134
+ assert first_company .batch == expected_first_item ["batch" ], (
135
+ f"Expected first company batch to be '{ expected_first_item ['batch' ]} ', "
136
+ f"but got '{ first_company .batch } '"
137
+ )
138
+
139
+ # Verify last company
140
+ expected_last_item = {
141
+ "company" : "Forefront" ,
142
+ "batch" : "1"
143
+ }
144
+ last_company = companies [- 1 ]
145
+ assert last_company .company == expected_last_item ["company" ], (
146
+ f"Expected last company to be '{ expected_last_item ['company' ]} ', "
147
+ f"but got '{ last_company .company } '"
148
+ )
149
+ assert last_company .batch == expected_last_item ["batch" ], (
150
+ f"Expected last company batch to be '{ expected_last_item ['batch' ]} ', "
151
+ f"but got '{ last_company .batch } '"
152
+ )
153
+
154
+ @pytest .mark .asyncio
155
+ @pytest .mark .regression
156
+ @pytest .mark .api
157
+ @pytest .mark .skipif (
158
+ not (os .getenv ("BROWSERBASE_API_KEY" ) and os .getenv ("BROWSERBASE_PROJECT_ID" )),
159
+ reason = "Browserbase credentials not available"
160
+ )
161
+ async def test_extract_aigrant_companies_browserbase (self , browserbase_stagehand ):
162
+ """
163
+ Regression test: extract_aigrant_companies (Browserbase)
164
+
165
+ Same test as local but running in Browserbase environment.
166
+ """
167
+ stagehand = browserbase_stagehand
168
+
169
+ await stagehand .page .goto ("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/" )
170
+
171
+ # Extract all companies with their batch numbers
172
+ extract_options = ExtractOptions (
173
+ instruction = (
174
+ "Extract all companies that received the AI grant and group them with their "
175
+ "batch numbers as an array of objects. Each object should contain the company "
176
+ "name and its corresponding batch number."
177
+ ),
178
+ schema_definition = Companies
179
+ )
180
+
181
+ result = await stagehand .page .extract (extract_options )
182
+
183
+ # TODO - how to unify the extract result handling between LOCAL and BROWSERBASE?
184
+
185
+ # Handle result based on the mode (LOCAL returns data directly, BROWSERBASE returns ExtractResult)
186
+ if hasattr (result , 'data' ) and result .data :
187
+ # BROWSERBASE mode format
188
+ companies_model = Companies .model_validate (result .data )
189
+ companies = companies_model .companies
190
+ else :
191
+ # LOCAL mode format - result is the Pydantic model instance
192
+ companies_model = Companies .model_validate (result .model_dump ())
193
+ companies = companies_model .companies
194
+
195
+ # Verify total count
196
+ expected_length = 91
197
+ assert len (companies ) == expected_length , (
198
+ f"Expected { expected_length } companies, but got { len (companies )} "
199
+ )
200
+
201
+ # Verify first company
202
+ expected_first_item = {
203
+ "company" : "Goodfire" ,
204
+ "batch" : "4"
205
+ }
206
+ assert len (companies ) > 0 , "No companies were extracted"
207
+ first_company = companies [0 ]
208
+ assert first_company .company == expected_first_item ["company" ], (
209
+ f"Expected first company to be '{ expected_first_item ['company' ]} ', "
210
+ f"but got '{ first_company .company } '"
211
+ )
212
+ assert first_company .batch == expected_first_item ["batch" ], (
213
+ f"Expected first company batch to be '{ expected_first_item ['batch' ]} ', "
214
+ f"but got '{ first_company .batch } '"
215
+ )
216
+
217
+ # Verify last company
218
+ expected_last_item = {
219
+ "company" : "Forefront" ,
220
+ "batch" : "1"
221
+ }
222
+ last_company = companies [- 1 ]
223
+ assert last_company .company == expected_last_item ["company" ], (
224
+ f"Expected last company to be '{ expected_last_item ['company' ]} ', "
225
+ f"but got '{ last_company .company } '"
226
+ )
227
+ assert last_company .batch == expected_last_item ["batch" ], (
228
+ f"Expected last company batch to be '{ expected_last_item ['batch' ]} ', "
229
+ f"but got '{ last_company .batch } '"
230
+ )
0 commit comments