Skip to content

Commit c927f70

Browse files
committed
subtrees implementation
1 parent ae69e4b commit c927f70

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

scrapegraphai/utils/aaa.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
import time
55

66
class TreeNode:
7-
def __init__(self, value=None, attributes=None, children=None, parent=None):
7+
def __init__(self, value=None, attributes=None, children=None, parent=None, depth=0):
88
self.value = value
99
self.attributes = attributes if attributes is not None else {}
1010
self.children = children if children is not None else []
1111
self.parent = parent
12+
self.depth = depth
1213
self.leads_to_text = False # Initialize the flag as False
1314
self.root_path = self._compute_root_path()
1415
self.closest_fork_path = self._compute_fork_path()
1516

1617
def add_child(self, child_node):
1718
child_node.parent = self
19+
child_node.depth = self.depth + 1
1820
self.children.append(child_node)
1921
child_node.update_paths()
2022
self.update_leads_to_text() # Update this node if the child leads to text
@@ -48,9 +50,18 @@ def _compute_fork_path(self):
4850
current = current.parent
4951
path.append(current.value) # Add the fork or root node
5052
return '>'.join(reversed(path))
51-
53+
54+
def get_subtrees(self):
55+
# This method finds and returns subtrees rooted at this node and all descendant forks
56+
subtrees = []
57+
if self.is_fork:
58+
subtrees.append(Tree(root=self))
59+
for child in self.children:
60+
subtrees.extend(child.get_subtrees())
61+
return subtrees
62+
5263
def __repr__(self):
53-
return f"TreeNode(value={self.value}, leads_to_text={self.leads_to_text}, root_path={self.root_path}, closest_fork_path={self.closest_fork_path})"
64+
return f"TreeNode(value={self.value}, leads_to_text={self.leads_to_text}, depth={self.depth}, root_path={self.root_path}, closest_fork_path={self.closest_fork_path})"
5465

5566
@property
5667
def is_fork(self):
@@ -72,6 +83,10 @@ def _traverse(node):
7283
_traverse(child)
7384
_traverse(self.root)
7485

86+
def get_subtrees(self):
87+
# Retrieves all subtrees rooted at fork nodes
88+
return self.root.get_subtrees() if self.root else []
89+
7590
def __repr__(self):
7691
return f"Tree(root={self.root})"
7792

@@ -89,7 +104,6 @@ def build_dom_tree(self, soup_node, tree_node):
89104
elif isinstance(child, NavigableString):
90105
text = child.strip()
91106
if text:
92-
# Create a text node with value 'text' and the actual content under 'content' key
93107
tree_node.add_child(TreeNode(value='text', attributes={'content': text}))
94108
elif isinstance(child, Tag):
95109
new_node = TreeNode(value=child.name, attributes=child.attrs)
@@ -98,14 +112,20 @@ def build_dom_tree(self, soup_node, tree_node):
98112

99113
# Usage example:
100114

101-
loader = AsyncHtmlLoader('https://www.mymovies.it/cinema/roma/')
115+
loader = AsyncHtmlLoader('https://github.com/PeriniM')
102116
document = loader.load()
103117
html_content = document[0].page_content
104118

105119
curr_time = time.time()
106120
# Instantiate a DOMTree with HTML content
107121
dom_tree = DOMTree(html_content)
122+
subtrees = dom_tree.get_subtrees() # Retrieve subtrees rooted at fork nodes
123+
108124
print(f"Time taken to build DOM tree: {time.time() - curr_time:.2f} seconds")
109125

126+
# Optionally, traverse each subtree
127+
for subtree in subtrees:
128+
print("Subtree rooted at:", subtree.root.value)
129+
# subtree.traverse(lambda node: print(node))
110130
# Traverse the DOMTree and print each node
111131
# dom_tree.traverse(lambda node: print(node))

0 commit comments

Comments
 (0)