Skip to content

Commit 400b947

Browse files
Add osworld axtree preprocessing
1 parent a40aa42 commit 400b947

File tree

1 file changed

+340
-0
lines changed

1 file changed

+340
-0
lines changed
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
import io
2+
import xml.etree.ElementTree as ET
3+
from typing import Tuple, List
4+
5+
from PIL import Image, ImageDraw, ImageFont
6+
7+
8+
def find_leaf_nodes(xlm_file_str):
    """Return every childless element of an XML document, in document order.

    Args:
        xlm_file_str: XML source text. A falsy value (``None`` or ``""``)
            yields an empty list without attempting to parse.

    Returns:
        A list of ``xml.etree.ElementTree.Element`` objects that have no
        child elements. A document consisting of a lone root returns that
        root.
    """
    if not xlm_file_str:
        return []

    document_root = ET.fromstring(xlm_file_str)

    # Element.iter() walks the tree depth-first in document order, which is
    # the same order a recursive pre-order traversal visits the leaves in.
    return [element for element in document_root.iter() if len(element) == 0]
27+
28+
# XML namespace URIs used to build the "{namespace}attribute" keys looked up
# on accessibility-tree nodes, one set per supported platform.
#
# NOTE(review): attributes_ns_ubuntu used to hold the *windows* attributes
# URI, unlike every other *_ubuntu constant here. It is aligned with the
# ubuntu host below on the assumption that this was a copy-paste slip;
# confirm against whatever component emits the ubuntu accessibility tree.
attributes_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
# Only a windows variant of the class namespace is defined.
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
37+
38+
39+
# Tag suffixes that make a node a candidate for keeping.
_KEPT_TAG_SUFFIXES = (
    "item",
    "button",
    "heading",
    "label",
    "scrollbar",
    "searchbox",
    "textbox",
    "link",
    "tabelement",
    "textfield",
    "textarea",
    "menu",
)

# Exact tag names that make a node a candidate for keeping.
_KEPT_TAGS = frozenset(
    {
        "alert",
        "canvas",
        "check-box",
        "combo-box",
        "entry",
        "icon",
        "image",
        "paragraph",
        "scroll-bar",
        "section",
        "slider",
        "static",
        "table-cell",
        "terminal",
        "text",
        "netuiribbontab",
        "start",
        "trayclockwclass",
        "traydummysearchcontrol",
        "uiimage",
        "uiproperty",
        "uiribboncommandbar",
    }
)


def _parse_number_pair(raw):
    """Parse text such as ``"(12, 34)"`` into ``(12, 34)``; return None if malformed."""
    import ast  # local import so this block does not rely on file-level imports

    try:
        # literal_eval only accepts Python literals, so attribute text taken
        # from the XML can never execute code (unlike eval()).
        value = ast.literal_eval(raw)
        first, second = value[0], value[1]
    except (ValueError, SyntaxError, TypeError, IndexError, KeyError):
        return None
    if not all(isinstance(item, (int, float)) for item in (first, second)):
        return None
    return first, second


def judge_node(node: ET.Element, platform="ubuntu", check_image=False) -> bool:
    """Decide whether an accessibility-tree node should be kept.

    A node is kept only when all of the following hold:

    * its tag starts with ``"document"``, ends with one of the kept suffixes,
      or is one of the kept tag names;
    * it is visible (on ubuntu it must be both ``showing`` and ``visible``);
    * at least one of ``enabled``/``editable``/``expandable``/``checkable``
      is ``"true"``;
    * it has a non-empty ``name`` attribute, non-empty text, or (only when
      ``check_image`` is set) an ``image`` attribute equal to ``"true"``;
    * its ``screencoord`` is non-negative and its ``size`` is strictly
      positive in both dimensions.

    Args:
        node: The XML element to test.
        platform: ``"ubuntu"`` or ``"windows"``; selects the namespaces used
            for the state and component attributes.
        check_image: Whether an ``image="true"`` attribute counts as content.

    Returns:
        True if the node passes every check above, False otherwise. A node
        whose ``screencoord`` or ``size`` cannot be parsed as a pair of
        numbers is rejected rather than raising.

    Raises:
        ValueError: If ``platform`` is neither ``"ubuntu"`` nor ``"windows"``.
    """
    if platform == "ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
    elif platform == "windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")

    def state_is_true(state_name):
        # State flags are stored as the strings "true"/"false"; a missing
        # flag counts as false.
        return node.get("{{{:}}}{:}".format(_state_ns, state_name), "false") == "true"

    tag = node.tag
    if not (tag.startswith("document") or tag.endswith(_KEPT_TAG_SUFFIXES) or tag in _KEPT_TAGS):
        return False

    # ubuntu additionally requires the "showing" state; windows only "visible".
    if not state_is_true("visible"):
        return False
    if platform == "ubuntu" and not state_is_true("showing"):
        return False

    if not any(
        state_is_true(state_name)
        for state_name in ("enabled", "editable", "expandable", "checkable")
    ):
        return False

    has_content = (
        node.get("name", "") != ""
        or (node.text is not None and len(node.text) > 0)
        or (check_image and node.get("image", "false") == "true")
    )
    if not has_content:
        return False

    # Missing geometry defaults to (-1, -1), which fails the checks below.
    coordinates = _parse_number_pair(
        node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)")
    )
    sizes = _parse_number_pair(node.get("{{{:}}}size".format(_component_ns), "(-1, -1)"))
    if coordinates is None or sizes is None:
        return False

    return coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
119+
120+
121+
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
    """Collect the elements under ``root`` that ``judge_node`` accepts.

    Args:
        root: Root element of the accessibility tree; it and all of its
            descendants are examined in document order.
        platform: Forwarded to ``judge_node``.
        check_image: Forwarded to ``judge_node``.

    Returns:
        A list of the accepted elements, in document order.
    """
    return [
        element
        for element in root.iter()
        if judge_node(element, platform, check_image)
    ]
130+
131+
132+
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0, platform="ubuntu"):
    """Draw a numbered red bounding box for each node onto a screenshot.

    Args:
        nodes: Iterable of XML elements. Each one is expected to carry
            ``screencoord`` and ``size`` attributes (in the platform's
            component namespace) formatted like ``"(x, y)"``; nodes missing
            either attribute are skipped.
        image_file_content: Encoded image bytes that PIL can open.
        down_sampling_ratio: Scale factor applied to the image and to the
            drawn boxes. The returned marks stay in original coordinates.
        platform: ``"ubuntu"`` or ``"windows"``; selects the attribute
            namespaces.

    Returns:
        A 4-tuple ``(marks, drew_nodes, element_table, image_content)``:
        ``marks`` is a list of ``[x, y, w, h]`` in original coordinates,
        ``drew_nodes`` the elements that were drawn, ``element_table`` a
        tab-separated ``index/tag/name/text`` listing, and ``image_content``
        the annotated image encoded as PNG bytes.

    Raises:
        ValueError: If ``platform`` is neither ``"ubuntu"`` nor ``"windows"``.
    """
    if platform == "ubuntu":
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "windows":
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")

    # Attribute keys are loop-invariant, so build them once.
    screencoord_key = "{{{:}}}screencoord".format(_component_ns)
    size_key = "{{{:}}}size".format(_component_ns)
    value_key = "{{{:}}}value".format(_value_ns)
    class_key = "{{{:}}}class".format(class_ns_windows)

    def quote_if_needed(value):
        # CSV-style escaping: only values containing a double quote are
        # wrapped in quotes, with embedded quotes doubled.
        if '"' not in value:
            return value
        return '"{:}"'.format(value.replace('"', '""'))

    # Load the screenshot image
    image = Image.open(io.BytesIO(image_file_content))
    needs_scaling = float(down_sampling_ratio) != 1.0
    if needs_scaling:
        image = image.resize(
            (int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio))
        )
    draw = ImageDraw.Draw(image)
    marks = []
    drew_nodes = []
    text_informations: List[str] = ["index\ttag\tname\ttext"]

    try:
        # Adjust the path to the font file you have or use a default one
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        # Fallback to a basic font if the specified font can't be loaded
        font = ImageFont.load_default()

    index = 1

    # Loop over all the visible nodes and draw their bounding boxes
    for _node in nodes:
        coords_str = _node.attrib.get(screencoord_key)
        size_str = _node.attrib.get(size_key)
        if not (coords_str and size_str):
            continue

        try:
            # Parse "(x, y)" style strings; int() tolerates surrounding
            # whitespace, so both "(1, 2)" and "(1,2)" are accepted.
            coords = tuple(map(int, coords_str.strip("()").split(",")))
            size = tuple(map(int, size_str.strip("()").split(",")))

            # Tuples of ints are immutable, so plain references are enough
            # to remember the unscaled geometry for the returned marks.
            original_coords = coords
            original_size = size

            if needs_scaling:
                # Downsample the coordinates and size
                coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
                size = tuple(int(s * down_sampling_ratio) for s in size)

            # Reject empty or negative sizes
            if size[0] <= 0 or size[1] <= 0:
                raise ValueError(f"Size must be positive, got: {size}")

            # Calculate the bottom-right corner of the bounding box
            bottom_right = (coords[0] + size[0], coords[1] + size[1])

            # Check that bottom_right >= coords (x1 >= x0, y1 >= y0)
            if bottom_right[0] < coords[0] or bottom_right[1] < coords[1]:
                raise ValueError(f"Invalid coordinates or size, coords: {coords}, size: {size}")

            # Skip regions that contain a single color only
            cropped_image = image.crop((*coords, *bottom_right))
            if len(set(cropped_image.getdata())) == 1:
                continue

            # Draw rectangle on image
            draw.rectangle([coords, bottom_right], outline="red", width=1)

            # Draw the index at the bottom-left corner of the box, white on
            # a black background sized to the rendered text.
            text_position = (coords[0], bottom_right[1])
            text_bbox = draw.textbbox(text_position, str(index), font=font, anchor="lb")
            draw.rectangle(text_bbox, fill="black")
            draw.text(text_position, str(index), font=font, anchor="lb", fill="white")

            # each mark is an x, y, w, h tuple
            marks.append(
                [original_coords[0], original_coords[1], original_size[0], original_size[1]]
            )
            drew_nodes.append(_node)

            if _node.text:
                node_text = quote_if_needed(_node.text)
            elif _node.get(class_key, "").endswith("EditWrapper") and _node.get(value_key):
                # Nodes whose class ends with "EditWrapper" fall back to
                # their value attribute when they have no text.
                node_text = quote_if_needed(_node.get(value_key, ""))
            else:
                node_text = '""'
            text_information: str = "{:d}\t{:}\t{:}\t{:}".format(
                index, _node.tag, _node.get("name", ""), node_text
            )
            text_informations.append(text_information)

            index += 1

        except ValueError:
            # Best effort: a node whose geometry cannot be parsed or drawn
            # is skipped instead of aborting the whole screenshot.
            pass

    output_image_stream = io.BytesIO()
    image.save(output_image_stream, format="PNG")
    image_content = output_image_stream.getvalue()

    return marks, drew_nodes, "\n".join(text_informations), image_content
259+
260+
261+
def print_nodes_with_indent(nodes, indent=0):
    """Print the tag and attributes of each node and of all its descendants.

    Nodes are printed depth-first, parents before children; every nesting
    level adds two spaces to the indentation prefix.

    Args:
        nodes: Iterable of XML elements (an element itself iterates over
            its children).
        indent: Number of spaces in the prefix for the top-level nodes.
    """
    exhausted = object()
    # Stack of (sibling iterator, indentation) pairs; the top entry is the
    # level currently being printed.
    pending = [(iter(nodes), indent)]
    while pending:
        siblings, depth = pending[-1]
        current = next(siblings, exhausted)
        if current is exhausted:
            pending.pop()
            continue
        print(" " * depth, current.tag, current.attrib)
        # Descend into the children before continuing with the siblings.
        pending.append((iter(current), depth + 2))
265+
266+
267+
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
    """Render the kept nodes of an accessibility tree as a tab-separated table.

    Args:
        accessibility_tree: XML source text of the accessibility tree.
        platform: ``"ubuntu"`` or ``"windows"``; selects the attribute
            namespaces and is forwarded to ``filter_nodes``.

    Returns:
        A newline-joined string: one header row followed by one row per
        node accepted by ``filter_nodes``, with the columns tag, name, text,
        class, description, position and size.

    Raises:
        ValueError: If ``platform`` is neither ``"ubuntu"`` nor ``"windows"``.
    """
    if platform == "ubuntu":
        _attributes_ns = attributes_ns_ubuntu
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "windows":
        _attributes_ns = attributes_ns_windows
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")

    # Fully-qualified attribute keys, built once instead of once per node.
    windows_class_key = "{{{:}}}class".format(class_ns_windows)
    value_key = "{{{:}}}value".format(_value_ns)
    description_key = "{{{:}}}description".format(_attributes_ns)
    position_key = "{{{:}}}screencoord".format(_component_ns)
    size_key = "{{{:}}}size".format(_component_ns)
    # ubuntu reads the class from the attributes namespace; otherwise the
    # windows class namespace is used.
    if platform == "ubuntu":
        class_key = "{{{:}}}class".format(_attributes_ns)
    else:
        class_key = windows_class_key

    def quote_if_needed(value):
        # Values containing a double quote are wrapped in quotes with the
        # embedded quotes doubled; everything else is passed through.
        if '"' in value:
            return '"{:}"'.format(value.replace('"', '""'))
        return value

    rows = ["tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"]

    # Linearize the accessibility tree nodes into a table format
    for element in filter_nodes(ET.fromstring(accessibility_tree), platform):
        if element.text:
            text = quote_if_needed(element.text)
        elif element.get(windows_class_key, "").endswith("EditWrapper") and element.get(
            value_key
        ):
            # Text-less nodes whose class ends with "EditWrapper" show
            # their value attribute instead.
            text = quote_if_needed(element.get(value_key, ""))
        else:
            text = '""'

        rows.append(
            "{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
                element.tag,
                element.get("name", ""),
                text,
                element.get(class_key, ""),
                element.get(description_key, ""),
                element.get(position_key, ""),
                element.get(size_key, ""),
            )
        )

    return "\n".join(rows)
321+
322+
323+
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
    """Annotate a screenshot with numbered boxes for the kept tree nodes.

    Args:
        screenshot: Encoded image bytes, passed to ``draw_bounding_boxes``.
        accessibility_tree: XML source text of the accessibility tree.
        platform: ``"ubuntu"`` or ``"windows"``; used both to filter the
            nodes and to draw them.

    Returns:
        A 4-tuple ``(marks, drew_nodes, tagged_screenshot, element_list)``.
        Note that the last two items are swapped relative to the order
        returned by ``draw_bounding_boxes``.
    """
    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
    # The platform must be forwarded: draw_bounding_boxes defaults to
    # "ubuntu" and would otherwise look for the geometry of windows nodes
    # under the ubuntu namespace and draw nothing.
    marks, drew_nodes, element_list, tagged_screenshot = draw_bounding_boxes(
        nodes, screenshot, platform=platform
    )

    return marks, drew_nodes, tagged_screenshot, element_list
329+
330+
331+
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    """Cut a linearized tree down to at most ``max_tokens`` tokens.

    Tokens are counted with the tiktoken encoding for the ``"gpt-4"`` model.

    Args:
        linearized_accessibility_tree: The text to trim.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        The input unchanged when it fits within ``max_tokens``; otherwise
        the decoded first ``max_tokens`` tokens followed by ``"[...]\\n"``.
    """
    import tiktoken

    encoder = tiktoken.encoding_for_model("gpt-4")
    token_ids = encoder.encode(linearized_accessibility_tree)
    if len(token_ids) <= max_tokens:
        return linearized_accessibility_tree

    # Over budget: keep the leading tokens and mark the cut.
    return encoder.decode(token_ids[:max_tokens]) + "[...]\n"
340+

0 commit comments

Comments
 (0)