Skip to content

Commit d2f8821

Browse files
authored
Add type information for crawler.py (#738)
Added type information to `crawler.py` to make it safer to use and understand.
1 parent a808974 commit d2f8821

File tree

2 files changed

+108
-82
lines changed

2 files changed

+108
-82
lines changed

docs/modules/agents/implementations/natbot.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import time
33

44
from langchain.chains.natbot.base import NatBotChain
5-
from langchain.chains.natbot.crawler import Crawler # type: ignore
5+
from langchain.chains.natbot.crawler import Crawler
66

77

88
def run_cmd(cmd: str, _crawler: Crawler) -> None:

langchain/chains/natbot/crawler.py

Lines changed: 107 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
# flake8: noqa
2-
# type: ignore
32
import time
43
from sys import platform
5-
6-
black_listed_elements = {
4+
from typing import (
5+
TYPE_CHECKING,
6+
Any,
7+
Dict,
8+
Iterable,
9+
List,
10+
Optional,
11+
Set,
12+
Tuple,
13+
TypedDict,
14+
Union,
15+
)
16+
17+
if TYPE_CHECKING:
18+
from playwright.sync_api import Browser, CDPSession, Page, sync_playwright
19+
20+
black_listed_elements: Set[str] = {
721
"html",
822
"head",
923
"title",
@@ -19,25 +33,42 @@
1933
}
2034

2135

36+
class ElementInViewPort(TypedDict):
37+
node_index: str
38+
backend_node_id: int
39+
node_name: Optional[str]
40+
node_value: Optional[str]
41+
node_meta: List[str]
42+
is_clickable: bool
43+
origin_x: int
44+
origin_y: int
45+
center_x: int
46+
center_y: int
47+
48+
2249
class Crawler:
23-
def __init__(self):
50+
def __init__(self) -> None:
2451
try:
2552
from playwright.sync_api import sync_playwright
2653
except ImportError:
2754
raise ValueError(
2855
"Could not import playwright python package. "
2956
"Please it install it with `pip install playwright`."
3057
)
31-
self.browser = sync_playwright().start().chromium.launch(headless=False)
32-
self.page = self.browser.new_page()
58+
self.browser: Browser = (
59+
sync_playwright().start().chromium.launch(headless=False)
60+
)
61+
self.page: Page = self.browser.new_page()
3362
self.page.set_viewport_size({"width": 1280, "height": 1080})
63+
self.page_element_buffer: Dict[int, ElementInViewPort]
64+
self.client: CDPSession
3465

35-
def go_to_page(self, url):
66+
def go_to_page(self, url: str) -> None:
3667
self.page.goto(url=url if "://" in url else "http://" + url)
3768
self.client = self.page.context.new_cdp_session(self.page)
3869
self.page_element_buffer = {}
3970

40-
def scroll(self, direction):
71+
def scroll(self, direction: str) -> None:
4172
if direction == "up":
4273
self.page.evaluate(
4374
"(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"
@@ -47,7 +78,7 @@ def scroll(self, direction):
4778
"(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;"
4879
)
4980

50-
def click(self, id):
81+
def click(self, id: Union[str, int]) -> None:
5182
# Inject javascript into the page which removes the target= attribute from all links
5283
js = """
5384
links = document.getElementsByTagName("a");
@@ -59,41 +90,37 @@ def click(self, id):
5990

6091
element = self.page_element_buffer.get(int(id))
6192
if element:
62-
x = element.get("center_x")
63-
y = element.get("center_y")
93+
x: float = element["center_x"]
94+
y: float = element["center_y"]
6495

6596
self.page.mouse.click(x, y)
6697
else:
6798
print("Could not find element")
6899

69-
def type(self, id, text):
100+
def type(self, id: Union[str, int], text: str) -> None:
70101
self.click(id)
71102
self.page.keyboard.type(text)
72103

73-
def enter(self):
104+
def enter(self) -> None:
74105
self.page.keyboard.press("Enter")
75106

76-
def crawl(self):
107+
def crawl(self) -> List[str]:
77108
page = self.page
78109
page_element_buffer = self.page_element_buffer
79110
start = time.time()
80111

81112
page_state_as_text = []
82113

83-
device_pixel_ratio = page.evaluate("window.devicePixelRatio")
114+
device_pixel_ratio: float = page.evaluate("window.devicePixelRatio")
84115
if platform == "darwin" and device_pixel_ratio == 1: # lies
85116
device_pixel_ratio = 2
86117

87-
win_scroll_x = page.evaluate("window.scrollX")
88-
win_scroll_y = page.evaluate("window.scrollY")
89-
win_upper_bound = page.evaluate("window.pageYOffset")
90-
win_left_bound = page.evaluate("window.pageXOffset")
91-
win_width = page.evaluate("window.screen.width")
92-
win_height = page.evaluate("window.screen.height")
93-
win_right_bound = win_left_bound + win_width
94-
win_lower_bound = win_upper_bound + win_height
95-
document_offset_height = page.evaluate("document.body.offsetHeight")
96-
document_scroll_height = page.evaluate("document.body.scrollHeight")
118+
win_upper_bound: float = page.evaluate("window.pageYOffset")
119+
win_left_bound: float = page.evaluate("window.pageXOffset")
120+
win_width: float = page.evaluate("window.screen.width")
121+
win_height: float = page.evaluate("window.screen.height")
122+
win_right_bound: float = win_left_bound + win_width
123+
win_lower_bound: float = win_upper_bound + win_height
97124

98125
# percentage_progress_start = (win_upper_bound / document_scroll_height) * 100
99126
# percentage_progress_end = (
@@ -116,40 +143,35 @@ def crawl(self):
116143
"DOMSnapshot.captureSnapshot",
117144
{"computedStyles": [], "includeDOMRects": True, "includePaintOrder": True},
118145
)
119-
strings = tree["strings"]
120-
document = tree["documents"][0]
121-
nodes = document["nodes"]
122-
backend_node_id = nodes["backendNodeId"]
123-
attributes = nodes["attributes"]
124-
node_value = nodes["nodeValue"]
125-
parent = nodes["parentIndex"]
126-
node_types = nodes["nodeType"]
127-
node_names = nodes["nodeName"]
128-
is_clickable = set(nodes["isClickable"]["index"])
129-
130-
text_value = nodes["textValue"]
131-
text_value_index = text_value["index"]
132-
text_value_values = text_value["value"]
133-
134-
input_value = nodes["inputValue"]
135-
input_value_index = input_value["index"]
136-
input_value_values = input_value["value"]
137-
138-
input_checked = nodes["inputChecked"]
139-
layout = document["layout"]
140-
layout_node_index = layout["nodeIndex"]
141-
bounds = layout["bounds"]
142-
143-
cursor = 0
144-
html_elements_text = []
145-
146-
child_nodes = {}
147-
elements_in_view_port = []
148-
149-
anchor_ancestry = {"-1": (False, None)}
150-
button_ancestry = {"-1": (False, None)}
151-
152-
def convert_name(node_name, has_click_handler):
146+
strings: Dict[int, str] = tree["strings"]
147+
document: Dict[str, Any] = tree["documents"][0]
148+
nodes: Dict[str, Any] = document["nodes"]
149+
backend_node_id: Dict[int, int] = nodes["backendNodeId"]
150+
attributes: Dict[int, Dict[int, Any]] = nodes["attributes"]
151+
node_value: Dict[int, int] = nodes["nodeValue"]
152+
parent: Dict[int, int] = nodes["parentIndex"]
153+
node_names: Dict[int, int] = nodes["nodeName"]
154+
is_clickable: Set[int] = set(nodes["isClickable"]["index"])
155+
156+
input_value: Dict[str, Any] = nodes["inputValue"]
157+
input_value_index: List[int] = input_value["index"]
158+
input_value_values: List[int] = input_value["value"]
159+
160+
layout: Dict[str, Any] = document["layout"]
161+
layout_node_index: List[int] = layout["nodeIndex"]
162+
bounds: Dict[int, List[float]] = layout["bounds"]
163+
164+
cursor: int = 0
165+
166+
child_nodes: Dict[str, List[Dict[str, Any]]] = {}
167+
elements_in_view_port: List[ElementInViewPort] = []
168+
169+
anchor_ancestry: Dict[str, Tuple[bool, Optional[int]]] = {"-1": (False, None)}
170+
button_ancestry: Dict[str, Tuple[bool, Optional[int]]] = {"-1": (False, None)}
171+
172+
def convert_name(
173+
node_name: Optional[str], has_click_handler: Optional[bool]
174+
) -> str:
153175
if node_name == "a":
154176
return "link"
155177
if node_name == "input":
@@ -163,7 +185,9 @@ def convert_name(node_name, has_click_handler):
163185
else:
164186
return "text"
165187

166-
def find_attributes(attributes, keys):
188+
def find_attributes(
189+
attributes: Dict[int, Any], keys: List[str]
190+
) -> Dict[str, str]:
167191
values = {}
168192

169193
for [key_index, value_index] in zip(*(iter(attributes),) * 2):
@@ -181,7 +205,13 @@ def find_attributes(attributes, keys):
181205

182206
return values
183207

184-
def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
208+
def add_to_hash_tree(
209+
hash_tree: Dict[str, Tuple[bool, Optional[int]]],
210+
tag: str,
211+
node_id: int,
212+
node_name: Optional[str],
213+
parent_id: int,
214+
) -> Tuple[bool, Optional[int]]:
185215
parent_id_str = str(parent_id)
186216
if not parent_id_str in hash_tree:
187217
parent_name = strings[node_names[parent_id]].lower()
@@ -195,7 +225,7 @@ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
195225

196226
# even if the anchor is nested in another anchor, we set the "root" for all descendants to be ::Self
197227
if node_name == tag:
198-
value = (True, node_id)
228+
value: Tuple[bool, Optional[int]] = (True, node_id)
199229
elif (
200230
is_parent_desc_anchor
201231
): # reuse the parent's anchor_id (which could be much higher in the tree)
@@ -212,7 +242,7 @@ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
212242

213243
for index, node_name_index in enumerate(node_names):
214244
node_parent = parent[index]
215-
node_name = strings[node_name_index].lower()
245+
node_name: Optional[str] = strings[node_name_index].lower()
216246

217247
is_ancestor_of_anchor, anchor_id = add_to_hash_tree(
218248
anchor_ancestry, "a", index, node_name, node_parent
@@ -253,7 +283,7 @@ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
253283
if not partially_is_in_viewport:
254284
continue
255285

256-
meta_data = []
286+
meta_data: List[str] = []
257287

258288
# inefficient to grab the same set of keys for kinds of objects, but it's fine for now
259289
element_attributes = find_attributes(
@@ -274,7 +304,7 @@ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
274304
else child_nodes.setdefault(str(ancestor_node_key), [])
275305
)
276306

277-
if node_name == "#text" and ancestor_exception:
307+
if node_name == "#text" and ancestor_exception and ancestor_node:
278308
text = strings[node_value[index]]
279309
if text == "|" or text == "•":
280310
continue
@@ -289,7 +319,7 @@ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
289319
) # prevent [button ... (button)..]
290320

291321
for key in element_attributes:
292-
if ancestor_exception:
322+
if ancestor_exception and ancestor_node:
293323
ancestor_node.append(
294324
{
295325
"type": "attribute",
@@ -344,36 +374,32 @@ def add_to_hash_tree(hash_tree, tag, node_id, node_name, parent_id):
344374
for element in elements_in_view_port:
345375
node_index = element.get("node_index")
346376
node_name = element.get("node_name")
347-
node_value = element.get("node_value")
348-
is_clickable = element.get("is_clickable")
349-
origin_x = element.get("origin_x")
350-
origin_y = element.get("origin_y")
351-
center_x = element.get("center_x")
352-
center_y = element.get("center_y")
353-
meta_data = element.get("node_meta")
354-
355-
inner_text = f"{node_value} " if node_value else ""
377+
element_node_value = element.get("node_value")
378+
node_is_clickable = element.get("is_clickable")
379+
node_meta_data: Optional[List[str]] = element.get("node_meta")
380+
381+
inner_text = f"{element_node_value} " if element_node_value else ""
356382
meta = ""
357383

358384
if node_index in child_nodes:
359-
for child in child_nodes.get(node_index):
385+
for child in child_nodes[node_index]:
360386
entry_type = child.get("type")
361387
entry_value = child.get("value")
362388

363-
if entry_type == "attribute":
389+
if entry_type == "attribute" and node_meta_data:
364390
entry_key = child.get("key")
365-
meta_data.append(f'{entry_key}="{entry_value}"')
391+
node_meta_data.append(f'{entry_key}="{entry_value}"')
366392
else:
367393
inner_text += f"{entry_value} "
368394

369-
if meta_data:
370-
meta_string = " ".join(meta_data)
395+
if node_meta_data:
396+
meta_string = " ".join(node_meta_data)
371397
meta = f" {meta_string}"
372398

373399
if inner_text != "":
374400
inner_text = f"{inner_text.strip()}"
375401

376-
converted_node_name = convert_name(node_name, is_clickable)
402+
converted_node_name = convert_name(node_name, node_is_clickable)
377403

378404
# not very elegant, more like a placeholder
379405
if (

0 commit comments

Comments
 (0)