Skip to content

Commit c60cd13

Browse files
authored
Merge pull request #10 from browserbase/miguel/bb-924-types-update-python-sdk
updates for BB-820: first crack
2 parents 3ad5094 + be7690d commit c60cd13

32 files changed

+2265
-453
lines changed

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
MODEL_API_KEY = "anthropic-or-openai-api-key"
2+
BROWSERBASE_API_KEY = "browserbase-api-key"
3+
BROWSERBASE_PROJECT_ID = "browserbase-project-id"
4+
STAGEHAND_SERVER_URL = "api_url"

MANIFEST.in

Lines changed: 0 additions & 6 deletions
This file was deleted.

evals/act/google_jobs.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
22
import traceback
3-
from typing import Optional, Any, Dict
3+
from typing import Any, Dict, Optional
4+
45
from pydantic import BaseModel
6+
57
from evals.init_stagehand import init_stagehand
68
from stagehand.schemas import ActOptions, ExtractOptions
79

@@ -49,12 +51,12 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
4951
- Clicking on the search button
5052
- Clicking on the first job link
5153
4. Extracting job posting details using an AI-driven extraction schema.
52-
54+
5355
The extraction schema requires:
5456
- applicationDeadline: The opening date until which applications are accepted.
5557
- minimumQualifications: An object with degree and yearsOfExperience.
5658
- preferredQualifications: An object with degree and yearsOfExperience.
57-
59+
5860
Returns a dictionary containing:
5961
- _success (bool): Whether valid job details were extracted.
6062
- jobDetails (dict): The extracted job details.
@@ -77,23 +79,25 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
7779

7880
try:
7981
await stagehand.page.navigate("https://www.google.com/")
80-
await asyncio.sleep(3)
82+
await asyncio.sleep(3)
8183
await stagehand.page.act(ActOptions(action="click on the about page"))
8284
await stagehand.page.act(ActOptions(action="click on the careers page"))
8385
await stagehand.page.act(ActOptions(action="input data scientist into role"))
8486
await stagehand.page.act(ActOptions(action="input new york city into location"))
8587
await stagehand.page.act(ActOptions(action="click on the search button"))
8688
await stagehand.page.act(ActOptions(action="click on the first job link"))
8789

88-
job_details = await stagehand.page.extract(ExtractOptions(
89-
instruction=(
90-
"Extract the following details from the job posting: application deadline, "
91-
"minimum qualifications (degree and years of experience), and preferred qualifications "
92-
"(degree and years of experience)"
93-
),
94-
schemaDefinition=JobDetails.model_json_schema(),
95-
useTextExtract=use_text_extract
96-
))
90+
job_details = await stagehand.page.extract(
91+
ExtractOptions(
92+
instruction=(
93+
"Extract the following details from the job posting: application deadline, "
94+
"minimum qualifications (degree and years of experience), and preferred qualifications "
95+
"(degree and years of experience)"
96+
),
97+
schemaDefinition=JobDetails.model_json_schema(),
98+
useTextExtract=use_text_extract,
99+
)
100+
)
97101

98102
valid = is_job_details_valid(job_details)
99103

@@ -104,19 +108,21 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
104108
"jobDetails": job_details,
105109
"debugUrl": debug_url,
106110
"sessionUrl": session_url,
107-
"logs": logger.get_logs() if hasattr(logger, "get_logs") else []
111+
"logs": logger.get_logs() if hasattr(logger, "get_logs") else [],
108112
}
109113
except Exception as e:
110114
err_message = str(e)
111115
err_trace = traceback.format_exc()
112-
logger.error({
113-
"message": "error in google_jobs function",
114-
"level": 0,
115-
"auxiliary": {
116-
"error": {"value": err_message, "type": "string"},
117-
"trace": {"value": err_trace, "type": "string"}
116+
logger.error(
117+
{
118+
"message": "error in google_jobs function",
119+
"level": 0,
120+
"auxiliary": {
121+
"error": {"value": err_message, "type": "string"},
122+
"trace": {"value": err_trace, "type": "string"},
123+
},
118124
}
119-
})
125+
)
120126

121127
await stagehand.close()
122128

@@ -125,31 +131,37 @@ async def google_jobs(model_name: str, logger, use_text_extract: bool) -> dict:
125131
"debugUrl": debug_url,
126132
"sessionUrl": session_url,
127133
"error": {"message": err_message, "trace": err_trace},
128-
"logs": logger.get_logs() if hasattr(logger, "get_logs") else []
129-
}
130-
134+
"logs": logger.get_logs() if hasattr(logger, "get_logs") else [],
135+
}
136+
137+
131138
# For quick local testing
132139
if __name__ == "__main__":
133-
import os
134140
import asyncio
135141
import logging
142+
136143
logging.basicConfig(level=logging.INFO)
137-
144+
138145
class SimpleLogger:
139146
def __init__(self):
140147
self._logs = []
148+
141149
def info(self, message):
142150
self._logs.append(message)
143151
print("INFO:", message)
152+
144153
def error(self, message):
145154
self._logs.append(message)
146155
print("ERROR:", message)
156+
147157
def get_logs(self):
148158
return self._logs
149159

150160
async def main():
151161
logger = SimpleLogger()
152-
result = await google_jobs("gpt-4o-mini", logger, use_text_extract=False) # TODO - use text extract
162+
result = await google_jobs(
163+
"gpt-4o-mini", logger, use_text_extract=False
164+
) # TODO - use text extract
153165
print("Result:", result)
154-
155-
asyncio.run(main())
166+
167+
asyncio.run(main())

evals/extract/extract_press_releases.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
import asyncio
2+
23
from pydantic import BaseModel
3-
from stagehand.schemas import ExtractOptions
4+
45
from evals.init_stagehand import init_stagehand
56
from evals.utils import compare_strings
7+
from stagehand.schemas import ExtractOptions
8+
69

710
# Define Pydantic models for validating press release data
811
class PressRelease(BaseModel):
912
title: str
1013
publish_date: str
1114

15+
1216
class PressReleases(BaseModel):
1317
items: list[PressRelease]
1418

19+
1520
async def extract_press_releases(model_name: str, logger, use_text_extract: bool):
1621
"""
1722
Extract press releases from the dummy press releases page using the Stagehand client.
18-
23+
1924
Args:
2025
model_name (str): Name of the AI model to use.
2126
logger: A custom logger that provides .error() and .get_logs() methods.
2227
use_text_extract (bool): Flag to control text extraction behavior.
23-
28+
2429
Returns:
2530
dict: A result object containing:
2631
- _success (bool): Whether the eval was successful.
@@ -34,12 +39,16 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
3439
session_url = None
3540
try:
3641
# Initialize Stagehand (mimicking the TS initStagehand)
37-
stagehand, init_response = await init_stagehand(model_name, logger, dom_settle_timeout_ms=3000)
42+
stagehand, init_response = await init_stagehand(
43+
model_name, logger, dom_settle_timeout_ms=3000
44+
)
3845
debug_url = init_response["debugUrl"]
3946
session_url = init_response["sessionUrl"]
4047

4148
# Navigate to the dummy press releases page # TODO - choose a different page
42-
await stagehand.page.navigate("https://dummy-press-releases.surge.sh/news", wait_until="networkidle")
49+
await stagehand.page.navigate(
50+
"https://dummy-press-releases.surge.sh/news", wait_until="networkidle"
51+
)
4352
# Wait for 5 seconds to ensure content has loaded
4453
await asyncio.sleep(5)
4554

@@ -49,7 +58,7 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
4958
ExtractOptions(
5059
instruction="extract the title and corresponding publish date of EACH AND EVERY press releases on this page. DO NOT MISS ANY PRESS RELEASES.",
5160
schemaDefinition=PressReleases.model_json_schema(),
52-
useTextExtract=use_text_extract
61+
useTextExtract=use_text_extract,
5362
)
5463
)
5564
print("Raw result:", raw_result)
@@ -73,19 +82,21 @@ async def extract_press_releases(model_name: str, logger, use_text_extract: bool
7382
expected_length = 28
7483
expected_first = PressRelease(
7584
title="UAW Region 9A Endorses Brad Lander for Mayor",
76-
publish_date="Dec 4, 2024"
85+
publish_date="Dec 4, 2024",
7786
)
7887
expected_last = PressRelease(
7988
title="Fox Sued by New York City Pension Funds Over Election Falsehoods",
80-
publish_date="Nov 12, 2023"
89+
publish_date="Nov 12, 2023",
8190
)
8291

8392
if len(items) <= expected_length:
84-
logger.error({
85-
"message": "Not enough items extracted",
86-
"expected": f"> {expected_length}",
87-
"actual": len(items)
88-
})
93+
logger.error(
94+
{
95+
"message": "Not enough items extracted",
96+
"expected": f"> {expected_length}",
97+
"actual": len(items),
98+
}
99+
)
89100
return {
90101
"_success": False,
91102
"error": "Not enough items extracted",
@@ -111,10 +122,9 @@ def is_item_match(item: PressRelease, expected: PressRelease) -> bool:
111122
await stagehand.close()
112123
return result
113124
except Exception as e:
114-
logger.error({
115-
"message": "Error in extract_press_releases function",
116-
"error": str(e)
117-
})
125+
logger.error(
126+
{"message": "Error in extract_press_releases function", "error": str(e)}
127+
)
118128
return {
119129
"_success": False,
120130
"error": str(e),
@@ -127,26 +137,33 @@ def is_item_match(item: PressRelease, expected: PressRelease) -> bool:
127137
if stagehand:
128138
await stagehand.close()
129139

140+
130141
# For quick local testing.
131142
if __name__ == "__main__":
132143
import logging
144+
133145
logging.basicConfig(level=logging.INFO)
134-
146+
135147
class SimpleLogger:
136148
def __init__(self):
137149
self._logs = []
150+
138151
def info(self, message):
139152
self._logs.append(message)
140153
print("INFO:", message)
154+
141155
def error(self, message):
142156
self._logs.append(message)
143157
print("ERROR:", message)
158+
144159
def get_logs(self):
145160
return self._logs
146161

147162
async def main():
148163
logger = SimpleLogger()
149-
result = await extract_press_releases("gpt-4o", logger, use_text_extract=False) # TODO - use text extract
164+
result = await extract_press_releases(
165+
"gpt-4o", logger, use_text_extract=False
166+
) # TODO - use text extract
150167
print("Result:", result)
151-
152-
asyncio.run(main())
168+
169+
asyncio.run(main())

evals/init_stagehand.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
11
import os
2-
import asyncio
2+
33
from stagehand import Stagehand
44
from stagehand.config import StagehandConfig
55

6+
67
async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3000):
78
"""
89
Initialize a Stagehand client with the given model name, logger, and DOM settle timeout.
9-
10+
1011
This function creates a configuration from environment variables, initializes the Stagehand client,
1112
and returns a tuple of (stagehand, init_response). The init_response contains debug and session URLs.
12-
13+
1314
Args:
1415
model_name (str): The name of the AI model to use.
1516
logger: A logger instance for logging errors and debug messages.
1617
dom_settle_timeout_ms (int): Milliseconds to wait for the DOM to settle.
17-
18+
1819
Returns:
1920
tuple: (stagehand, init_response) where init_response is a dict containing:
2021
- "debugUrl": A dict with a "value" key for the debug URL.
2122
- "sessionUrl": A dict with a "value" key for the session URL.
2223
"""
2324
# Build a Stagehand configuration object using environment variables
2425
config = StagehandConfig(
25-
env="BROWSERBASE" if os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID") else "LOCAL",
26+
env=(
27+
"BROWSERBASE"
28+
if os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")
29+
else "LOCAL"
30+
),
2631
api_key=os.getenv("BROWSERBASE_API_KEY"),
2732
project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
2833
debug_dom=True,
@@ -33,7 +38,9 @@ async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3
3338
)
3439

3540
# Create a Stagehand client with the configuration; server_url is taken from environment variables.
36-
stagehand = Stagehand(config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2)
41+
stagehand = Stagehand(
42+
config=config, server_url=os.getenv("STAGEHAND_SERVER_URL"), verbose=2
43+
)
3744
await stagehand.init()
3845

3946
# Construct the URL from the session id using the new format.
@@ -43,4 +50,4 @@ async def init_stagehand(model_name: str, logger, dom_settle_timeout_ms: int = 3
4350
url = f"wss://connect.browserbase.com?apiKey={api_key}&sessionId={stagehand.session_id}"
4451

4552
# Return both URLs as dictionaries with the "value" key.
46-
return stagehand, {"debugUrl": {"value": url}, "sessionUrl": {"value": url}}
53+
return stagehand, {"debugUrl": {"value": url}, "sessionUrl": {"value": url}}

0 commit comments

Comments
 (0)