diff --git a/.gitignore b/.gitignore index ef75a62..bacacee 100644 --- a/.gitignore +++ b/.gitignore @@ -217,4 +217,7 @@ playground/.env !launcher/** multi_app_tasks_backup -multi_app_tasks \ No newline at end of file +multi_app_tasks + +!evaluation/AndroidWorld/android_world/env/ +!evaluation/AndroidWorld/android_world/env/**/* \ No newline at end of file diff --git a/evaluation/AndroidWorld/README.md b/evaluation/AndroidWorld/README.md index a38ca13..92075dc 100644 --- a/evaluation/AndroidWorld/README.md +++ b/evaluation/AndroidWorld/README.md @@ -24,3 +24,19 @@ Please refer to [AndroidWorld README](docs/README_AndroidWorld.md) * **\[CONSOLE\_PORT]** is the port for the agent’s console * **\[CHECKPOINT\_DIR]** is the path to the directory containing your model checkpoints * **\[GRPC\_PORT]** is the port for the gRPC service + +## Qwen3-VL Model Evaluation + +We have adapted the prompts and action space of the Qwen3-VL series models to reproduce their evaluation results. + +1. **Launch the Android emulator first (example):** + ```bash + emulator -avd AndroidWorldAVD -no-snapshot -grpc 8554 + ``` + +2. **After deploying your model API with `vLLM` (refer to [model development](../README.md#-model-development)), configure the `model_url` and `model_name`, e.g., `http://:8000/v1` and `Qwen3-VL-8B-Instruct`.** + +3. **Run the evaluation using the following script (example):** + ``` + python run.py --agent_name qwen3vl --console_port 5554 --grpc_port 8554 --perform_emulator_setup=true --qwen3vl_model_base_url model_url --qwen3vl_model_name model_name --qwen3vl_model_api_key EMPTY --checkpoint_dir runs/qwen3vl_8b_instruct + ``` \ No newline at end of file diff --git a/evaluation/AndroidWorld/android_world/agents/PROMPT.py b/evaluation/AndroidWorld/android_world/agents/PROMPT.py index a47dc11..d33765c 100644 --- a/evaluation/AndroidWorld/android_world/agents/PROMPT.py +++ b/evaluation/AndroidWorld/android_world/agents/PROMPT.py @@ -79,6 +79,14 @@ """ ) +# ========================= +# Qwen3VL tool-call prompts +# ========================= + +QWEN3VL_SYSTEM_PROMPT = "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"mobile_use\", \"description\": \"Use a touchscreen to interact with a mobile device, and take screenshots.\\n* This is an interface to a mobile device with touchscreen. You can perform actions like clicking, typing, swiping, etc.\\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions.\\n* The screen's resolution is 999x999.\\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\", \"parameters\": {\"properties\": {\"action\": {\"description\": \"The action to perform. The available actions are:\\n* `click`: Click the point on the screen with coordinate (x, y).\\n* `long_press`: Press the point on the screen with coordinate (x, y) for specified seconds.\\n* `swipe`: Swipe from the starting point with coordinate (x, y) to the end point with coordinates2 (x2, y2).\\n* `type`: Input the specified text into the activated input box.\\n* `answer`: Output the answer.\\n* `system_button`: Press the system button.\\n* `wait`: Wait specified seconds for the change to happen.\\n* `terminate`: Terminate the current task and report its completion status.\", \"enum\": [\"click\", \"long_press\", \"swipe\", \"type\", \"answer\", \"system_button\", \"wait\", \"terminate\"], \"type\": \"string\"}, \"coordinate\": {\"description\": \"(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=click`, `action=long_press`, and `action=swipe`.\", \"type\": \"array\"}, \"coordinate2\": {\"description\": \"(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=swipe`.\", \"type\": \"array\"}, \"text\": {\"description\": \"Required only by `action=type` and `action=answer`.\", \"type\": \"string\"}, \"time\": {\"description\": \"The seconds to wait. Required only by `action=long_press` and `action=wait`.\", \"type\": \"number\"}, \"button\": {\"description\": \"Back means returning to the previous interface, Home means returning to the desktop, Menu means opening the application background menu, and Enter means pressing the enter. Required only by `action=system_button`\", \"enum\": [\"Back\", \"Home\", \"Menu\", \"Enter\"], \"type\": \"string\"}, \"status\": {\"description\": \"The status of the task. Required only by `action=terminate`.\", \"type\": \"string\", \"enum\": [\"success\", \"failure\"]}}, \"required\": [\"action\"], \"type\": \"object\"}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n\n# Response format\n\nResponse format for every step:\n1) Thought: one concise sentence explaining the next move (no multi-step reasoning).\n2) Action: a short imperative describing what to do in the UI.\n3) A single ... block containing only the JSON: {\"name\": , \"arguments\": }.\n\nRules:\n- Output exactly in the order: Thought, Action, .\n- Be brief: one sentence for Thought, one for Action.\n- Do not output anything else outside those three parts.\n- If finishing, use action=terminate in the tool call." + +QWEN3VL_USER_PROMPT = "The user query: {instruction}.\nTask progress (You have done the following operation on the current device): {history}.\n" + SUMMARY_PROMPT_TEMPLATE = ( PROMPT_PREFIX diff --git a/evaluation/AndroidWorld/android_world/agents/seeact_v.py b/evaluation/AndroidWorld/android_world/agents/seeact_v.py index ce35760..ba701d2 100644 --- a/evaluation/AndroidWorld/android_world/agents/seeact_v.py +++ b/evaluation/AndroidWorld/android_world/agents/seeact_v.py @@ -18,6 +18,7 @@ import time import ast +import json import numpy as np from PIL import Image from openai import OpenAI @@ -40,7 +41,13 @@ from android_world.env import interface from android_world.env import json_action from android_world.env import representation_utils -from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize + +try: + from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import ( # type: ignore + smart_resize, + ) +except Exception: # pragma: no cover + smart_resize = None # type: ignore[assignment] # Utils for Visual Grounding @@ -932,3 +939,184 @@ def _to_base64_png(image: np.ndarray) -> str: buf = BytesIO() PILImage.fromarray(image).save(buf, format="PNG") return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + + +def _extract_action_text_qwen3vl(block: str) -> str: + """Extracts the 'Action:' line from Qwen3VL text output for step history (does not affect execution).""" + m = re.search(r"Action:\s*(.+?)(?:\n|$)", block, flags=re.S) + if not m: + return "" + text = m.group(1).strip() + # Some models wrap Action: "..." with quotes. + if text.startswith('"') and text.endswith('"'): + text = text[1:-1] + return text.replace("\n", " ") + + +def _parse_tool_call_json(block: str) -> dict[str, Any] | None: + """Parse JSON inside ....""" + m = re.search(r"\s*([\s\S]*?)\s*", block) + if not m: + return None + payload = m.group(1).strip() + try: + return json.loads(payload) + except Exception: + return None + + +class Qwen3VL(base_agent.EnvironmentInteractingAgent): + """Android GUI Agent based on Qwen3VL tool-call output (for AndroidWorld eval). + + - Input: Screenshot + instruction + history + - Output: {...} + - Execution: Map to JSONAction by qwen3vl_action_transform(...) + """ + + def __init__( + self, + env: interface.AsyncEnv, + llm: infer.MultimodalLlmWrapper, + name: str = "Qwen3VL", + wait_after_action_seconds: float = 2.0, + model_base_url: str = "http://127.0.0.1:8000/v1", + model_api_key: str = "EMPTY", + model_name: str = "", + extra_headers: dict[str, str] | None = None, + ): + super().__init__(env, name) + self.llm = llm + self.wait_after_action_seconds = wait_after_action_seconds + self.model_name = model_name + self.client = OpenAI( + api_key=model_api_key, + base_url=model_base_url, + default_headers=extra_headers, + ) + self.step_his: str = "" + self.turn_number: int = 0 + # Used to detect repeated actions (avoid infinite loops) + self.last_action: str | None = None + self.repeat_time: int = 0 + + def reset(self, go_home_on_reset: bool = False): + super().reset(go_home_on_reset) + self.env.hide_automation_ui() + self.step_his = "" + self.turn_number = 0 + self.last_action = None + self.repeat_time = 0 + + @staticmethod + def _to_base64_png(image: np.ndarray) -> str: + import base64 + from io import BytesIO + from PIL import Image as PILImage + buf = BytesIO() + PILImage.fromarray(image).save(buf, format='PNG') + return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + + def step(self, instruction: str) -> base_agent.AgentInteractionResult: + self.turn_number += 1 + + state = self.get_post_transition_state() + screenshot = state.pixels.copy() + # To be consistent with other agents in this file: BGR->RGB (for saving/encoding) + screenshot = screenshot[:, :, ::-1] + height, width = screenshot.shape[:2] + + system_prompt = QWEN3VL_SYSTEM_PROMPT + user_prompt = QWEN3VL_USER_PROMPT.format( + instruction=instruction, history=self.step_his + ) + + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": system_prompt}], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": user_prompt}, + {"type": "image_url", "image_url": {"url": self._to_base64_png(screenshot)}}, + ], + }, + ] + + completion = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + temperature=0, + ) + response = completion.choices[0].message.content or "" + print(response) + print("=" * 50) + + tool_call = _parse_tool_call_json(response) + if not tool_call: + return base_agent.AgentInteractionResult( + True, {"summary": "No JSON found in model output.", "response": response} + ) + + op_text = _extract_action_text_qwen3vl(response) + if op_text: + self.step_his += f"Step {self.turn_number}: {op_text}; " + + # Compatible: tool_call may look like {"name":"mobile_use","arguments":{...}} + args = tool_call.get("arguments", {}) if isinstance(tool_call, dict) else {} + action_name = args.get("action", "") + try: + parsed = qwen3vl_action_transform(action_name, args, width, height) + print(parsed) + except Exception as e: + return base_agent.AgentInteractionResult( + True, + { + "summary": f"Failed to transform tool-call into action: {e}", + "response": response, + "tool_call": tool_call, + }, + ) + + # Record last_action + repeat_time (previous code had these fields but not working) + # Here, use the tool-call's arguments as the "action signature", which is more robust than checking 'terminate' in a string. + try: + action_sig = json.dumps(args, ensure_ascii=False, sort_keys=True) + except Exception: + action_sig = str(args) + if self.last_action == action_sig: + self.repeat_time += 1 + else: + self.repeat_time = 0 + self.last_action = action_sig + + try: + act = json_action.JSONAction(**parsed) + self.env.execute_action(act) + time.sleep(self.wait_after_action_seconds) + except Exception: + # continue + print("Failed to execute action:", parsed) + + if parsed.get("action_type") == "status": + return base_agent.AgentInteractionResult( + True, {"response": response, "step_history": self.step_his, "parsed": parsed} + ) + + # If repeated actions reach the threshold: terminate immediately to avoid deadlock in evaluation + if self.repeat_time >= 3: + return base_agent.AgentInteractionResult( + True, + { + "summary": "Terminated due to repeated identical actions.", + "response": response, + "step_history": self.step_his, + "parsed": parsed, + "repeat_time": self.repeat_time, + }, + ) + + return base_agent.AgentInteractionResult( + False, {"response": response, "step_history": self.step_his, "parsed": parsed} + ) diff --git a/evaluation/AndroidWorld/android_world/agents/utils.py b/evaluation/AndroidWorld/android_world/agents/utils.py index d7fa7cc..c9c020a 100644 --- a/evaluation/AndroidWorld/android_world/agents/utils.py +++ b/evaluation/AndroidWorld/android_world/agents/utils.py @@ -7,7 +7,13 @@ from android_world.env import interface from android_world.env import json_action from android_world.agents import base_agent -from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize + +try: + from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import ( # type: ignore + smart_resize, + ) +except Exception: # pragma: no cover + smart_resize = None # type: ignore[assignment] def _extract_xy(s: str) -> Tuple[float, float] | None: @@ -156,6 +162,60 @@ def action_transform(action: str, width: int, height: int) -> Dict[str, Any] | N return None +def qwen3vl_action_transform(action, arguments, width, height) -> Dict[str, Any]: + if action == "key": + return {"action_type": "wait"} + elif action == "click" or action == "left_click": + coordinate = arguments.get("coordinate", [0, 0]) + x, y = coordinate + x = x / 1000 * width + y = y / 1000 * height + return {"action_type": "click", "x": x, "y": y} + elif action == "long_press": + coordinate = arguments.get("coordinate", [0, 0]) + x, y = coordinate + x = x / 1000 * width + y = y / 1000 * height + return {"action_type": "long_press", "x": x, "y": y} + elif action == "swipe": + coordinate = arguments.get("coordinate", [0, 0]) + coordinate2 = arguments.get("coordinate2", [0, 0]) + x0, y0 = coordinate[0]/1000 * width, coordinate[1]/1000 * height + x1, y1 = coordinate2[0]/1000 * width, coordinate2[1]/1000 * height + dir_ = _dir_from_coords(x0, y0, x1, y1) + return {"action_type": "scroll", "direction": reverse_direction(dir_)} + elif action == "type": + text = arguments.get("text", "") + return {"action_type": "input_text", "text": text} + elif action == "system_button": + button = arguments.get("button", "").lower() + if button == "home": + return {"action_type": "navigate_home"} + elif button == "back": + return {"action_type": "navigate_back"} + else: + raise ValueError(f"Unknown system button: {button}") + elif action == "open": + text = arguments.get("text", "") + return {"action_type": "open_app", "app_name": text} + elif action == "wait": + return {"action_type": "wait"} + elif action == "answer": + return {"action_type": "answer", "text": arguments.get("text", "")} + elif action == "terminate": + status = arguments.get("status", "").lower() + if status == "success": + return {"action_type": "status", "goal_status": "complete"} + elif status == "failure": + return {"action_type": "status", "goal_status": "infeasible"} + else: + raise ValueError(f"Unknown terminate status: {status}") + # else: + # raise ValueError(f"Unknown action: {action}") + else: + return {'action_type': 'wait'} + + def action_coord(action): def extract_click_json(s): m = re.search( diff --git a/evaluation/AndroidWorld/android_world/env/actuation.py b/evaluation/AndroidWorld/android_world/env/actuation.py new file mode 100644 index 0000000..b33c32b --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/actuation.py @@ -0,0 +1,293 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilies for actuation.""" + +import copy +import logging +import time +from typing import Any +from android_env import env_interface +from android_world.env import adb_utils +from android_world.env import android_world_controller +from android_world.env import json_action +from android_world.env import representation_utils + + +def execute_adb_action( + action: json_action.JSONAction, + screen_elements: list[Any], # list[UIElement] + screen_size: tuple[int, int], + env: env_interface.AndroidEnvInterface, +) -> None: + """Execute an action based on a JSONAction object. + + Args: + action: JSONAction object containing the action to be executed. + screen_elements: List of UI elements on the screen. + screen_size: The (width, height) of the screen. + env: The environment to execute the action in. + """ + if action.action_type in ['click', 'double_tap', 'long_press']: + idx = action.index + x = action.x + y = action.y + + + + + + + if x is not None and y is not None: + + print(f"Correctly get into CLICK XY {x,y}") + x = round(x) + y = round(y) + + if action.action_type == 'click': + adb_utils.tap_screen(x, y, env) + elif action.action_type == 'double_tap': + adb_utils.double_tap(x, y, env) + else: + adb_utils.long_press(x, y, env) + elif idx is not None: + if idx < 0 or idx >= len(screen_elements): + raise ValueError( + f'Invalid element index: {idx}, must be between 0 and' + f' {len(screen_elements) - 1}.' + ) + element = screen_elements[idx] + if element.bbox_pixels is None: + raise ValueError('Bbox is not present on element.') + x, y = element.bbox_pixels.center + x, y = int(x), int(y) + if action.action_type == 'click': + adb_utils.tap_screen(x, y, env) + elif action.action_type == 'double_tap': + adb_utils.double_tap(x, y, env) + else: + adb_utils.long_press(x, y, env) + else: + raise ValueError(f'Invalid click action: {action}') + + + + elif action.action_type == 'input_text': + text = action.text + if text: + if action.index is not None or ( + action.x is not None and action.y is not None + ): + # First focus on enter text UI element. + click_action = copy.deepcopy(action) + click_action.action_type = 'click' + execute_adb_action(click_action, screen_elements, screen_size, env) + time.sleep(1.0) + adb_utils.type_text(text, env, timeout_sec=10) + adb_utils.press_enter_button(env) + else: + logging.warning( + 'Input_text action indicated, but no text provided. No ' + 'action will be executed.' + ) + + elif action.action_type == 'keyboard_enter': + adb_utils.press_enter_button(env) + + elif action.action_type == 'navigate_home': + adb_utils.press_home_button(env) + + elif action.action_type == 'navigate_back': + adb_utils.press_back_button(env) + + elif action.action_type == 'press_keyboard': + adb_utils.press_keyboard_generic(action.keycode, env) + + elif action.action_type == 'scroll': + + screen_width, screen_height = screen_size + if action.index: + x_min, y_min, x_max, y_max = ( + max(screen_elements[action.index].bbox_pixels.x_min, 0), + max(screen_elements[action.index].bbox_pixels.y_min, 0), + min(screen_elements[action.index].bbox_pixels.x_max, screen_width), + min(screen_elements[action.index].bbox_pixels.y_max, screen_height), + ) + else: + x_min, y_min, x_max, y_max = (0, 0, screen_width, screen_height) + + start_x, start_y = (x_min + x_max) // 2, (y_min + y_max) // 2 + direction = action.direction + if direction == 'down': + end_x, end_y = (x_min + x_max) // 2, y_min + elif direction == 'up': + end_x, end_y = (x_min + x_max) // 2, y_max + elif direction == 'right': + end_x, end_y = x_min, (y_min + y_max) // 2 + elif direction == 'left': + end_x, end_y = x_max, (y_min + y_max) // 2 + else: + print('Invalid direction') + return + command = adb_utils.generate_swipe_command( + int(start_x), int(start_y), int(end_x), int(end_y) + ) + adb_utils.issue_generic_request(command, env) + + elif action.action_type == 'swipe': # Inverse of scroll. + screen_width, screen_height = screen_size + mid_x, mid_y = 0.5 * screen_width, 0.5 * screen_height + direction = action.direction + if direction == 'down': + start_x, start_y = mid_x, 0 + end_x, end_y = mid_x, screen_height + elif direction == 'up': + start_x, start_y = mid_x, screen_height + end_x, end_y = mid_x, 0 + elif direction == 'left': + start_x, start_y = 0, mid_y + end_x, end_y = screen_width, mid_y + elif direction == 'right': + start_x, start_y = screen_width, mid_y + end_x, end_y = 0, mid_y + else: + print('Invalid direction') + return + command = adb_utils.generate_swipe_command( + int(start_x), int(start_y), int(end_x), int(end_y), 500 + ) + adb_utils.issue_generic_request(command, env) + + elif action.action_type == 'open_app': + app_name = action.app_name + if app_name: + adb_utils.launch_app(app_name, env) + else: + raise ValueError('No app name provided') + + elif action.action_type == 'wait': + time.sleep(1.0) + + elif action.action_type == 'launch_adb_activity': + if action.activity_nickname == 'app_drawer': + adb_utils.press_home_button(env) + time.sleep(1.0) + start_x, start_y = int(screen_size[0] / 2), int(screen_size[1] * 0.9) + end_x = start_x + end_y = int(0.3 * screen_size[1]) + request = adb_utils.generate_swipe_command(start_x, start_y, end_x, end_y) + adb_utils.issue_generic_request(request, env) + elif action.activity_nickname == 'quick_settings': + start_x, start_y = int(screen_size[0] / 2), 30 + end_x = start_x + end_y = int(0.3 * screen_size[1]) + request = adb_utils.generate_swipe_command( + start_x, start_y, end_x, end_y, duration_ms=10 + ) + adb_utils.issue_generic_request(request, env) + elif action.action_type == 'change_orientation': + adb_utils.change_orientation(action.orientation, env) + elif action.action_type == json_action.UNKNOWN: + print('Unknown action type; no action will be executed. Try again...') + else: + print('Invalid action type') + + +def find_and_click_element( + element_text: str, + env: android_world_controller.AndroidWorldController, + case_sensitive: bool = False, +): + """Identifies element with element_text and clicks it. + + Args: + element_text: Text of the UI element to click on. + env: The Android env instance. + case_sensitive: Whether to use case sensitivity when determining which UI + element to tap. + """ + # Find text. + action = _wait_and_find_click_element(element_text, env, case_sensitive) + + ui_elements = env.get_ui_elements() + screen_size = (0, 0) # Unused, but required. + execute_adb_action(action, ui_elements, screen_size, env) + + +def _wait_and_find_click_element( + target_text: str, + env: android_world_controller.AndroidWorldController, + case_sensitive: bool, + dist_threshold: int = 1, # Allow one character difference. +) -> json_action.JSONAction: + """Wait for the screen to update until "element_text" appears.""" + ui_elements = env.get_ui_elements() + element, distance = _find_target_element( + ui_elements, target_text, case_sensitive + ) + start = time.time() + current = time.time() + while current - start < 10: + if distance <= dist_threshold: + return json_action.JSONAction(action_type='click', index=element) + ui_elements = env.get_ui_elements() + element, distance = _find_target_element( + ui_elements, target_text, case_sensitive + ) + current = time.time() + raise ValueError(f'Target text "{target_text}" not found.') + + +def _find_target_element( + ui_elements: list[representation_utils.UIElement], + target_text: str, + case_sensitive: bool, +) -> tuple[int, int]: + """Determine the UI element with the closest match to target_text, by looking at the `text` and `content_description` of each UI element.""" + best_match_index = -1 + lowest_distance = int(1e9) + + for i, element in enumerate(ui_elements): + for attr in [element.text, element.content_description]: + if attr is not None: + if case_sensitive: + distance = _levenshtein_distance(target_text, attr) + else: + distance = _levenshtein_distance(target_text.lower(), attr.lower()) + if distance < lowest_distance: + lowest_distance = distance + best_match_index = i + + return (best_match_index, lowest_distance) + + +def _levenshtein_distance(s1: str, s2: str) -> int: + """Compute the Levenshtein distance between two strings.""" + if len(s1) < len(s2): + s1, s2 = s2, s1 + + if not s2: + return len(s1) + + previous_row = range(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] diff --git a/evaluation/AndroidWorld/android_world/env/actuation_test.py b/evaluation/AndroidWorld/android_world/env/actuation_test.py new file mode 100644 index 0000000..f46c04e --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/actuation_test.py @@ -0,0 +1,284 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import time +from unittest import mock + +from absl.testing import absltest +from android_env import env_interface +from android_world.env import actuation +from android_world.env import adb_utils +from android_world.env import android_world_controller +from android_world.env import json_action +from android_world.env import representation_utils + + +@mock.patch.object(time, 'sleep') +@mock.patch.object(actuation, '_find_target_element') +@mock.patch.object(android_world_controller, 'get_a11y_tree') +@mock.patch.object(representation_utils, 'forest_to_ui_elements') +class TestWaitAndFindClickElement(absltest.TestCase): + + def test_element_found_immediately( + self, + unused_mock_representation_utils, + unused_mock_get_a11y_tree, + mock_create, + mock_sleep, + ): + """Test when the element is found immediately.""" + mock_create.return_value = (0, 0) + mock_sleep.side_effect = [0, 1] + action = actuation._wait_and_find_click_element( + 'target', mock.MagicMock(), case_sensitive=True + ) + self.assertEqual( + action, json_action.JSONAction(action_type='click', index=0) + ) + + def test_element_not_found_within_timeout( + self, + unused_mock_representation_utils, + unused_mock_get_a11y_tree, + mock_create, + mock_sleep, + ): + """Test when the element is not found within the timeout period.""" + mock_create.return_value = (-1, float('inf')) + mock_sleep.side_effect = ( + 0, + 11, + ) # Simulating 11 seconds have passed + with self.assertRaises(ValueError): + actuation._wait_and_find_click_element( + 'target', mock.MagicMock(), case_sensitive=True + ) + + +class TestCreateReferredClickAction(absltest.TestCase): + + def test_empty_ui_elements(self): + """Test with no UI elements.""" + self.assertEqual( + actuation._find_target_element([], 'target', case_sensitive=True), + (-1, int(1e9)), + ) + + def test_single_exact_match(self): + """Test with one UI element that is an exact match.""" + ui_elements = [ + representation_utils.UIElement(text='target', content_description='') + ] + self.assertEqual( + actuation._find_target_element( + ui_elements, 'target', case_sensitive=True + ), + (0, 0), + ) + + def test_multiple_elements_with_closest_match(self): + """Test with multiple elements where one is the closest match.""" + ui_elements = [ + representation_utils.UIElement(text='targ', content_description=''), + representation_utils.UIElement(text='', content_description='targetX'), + representation_utils.UIElement(text='target', content_description=''), + ] + self.assertEqual( + actuation._find_target_element( + ui_elements, 'target', case_sensitive=True + ), + (2, 0), + ) + + def test_no_exact_match(self): + """Test with no exact matching elements.""" + ui_elements = [ + representation_utils.UIElement(text='no match', content_description='') + ] + _, distance = actuation._find_target_element( + ui_elements, 'target', case_sensitive=True + ) + self.assertGreater(distance, 0) + + +class ExecuteAdbActionTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_env = mock.create_autospec(spec=env_interface.AndroidEnvInterface) + self.screen_elements = [ + representation_utils.UIElement( + bbox_pixels=representation_utils.BoundingBox( + x_min=0, x_max=50, y_min=0, y_max=60 + ) + ) + ] + self.screen_size = (100, 100) + + def test_click_by_index(self): + action = json_action.JSONAction(action_type='click', index=0) + with mock.patch.object(adb_utils, 'tap_screen') as mock_tap_screen: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_tap_screen.assert_called_once_with(25, 30, self.mock_env) + + def test_click_by_coordinates(self): + action = json_action.JSONAction(action_type='click', x=50, y=50) + with mock.patch.object(adb_utils, 'tap_screen') as mock_tap_screen: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_tap_screen.assert_called_once_with(50, 50, self.mock_env) + + def test_click_by_coordinate_floats(self): + action = json_action.JSONAction(action_type='click', x=50.2, y=50.3) + with mock.patch.object(adb_utils, 'tap_screen') as mock_tap_screen: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_tap_screen.assert_called_once_with(50, 50, self.mock_env) + + def test_input_text(self): + action = json_action.JSONAction( + action_type='input_text', text='test input', x=50, y=50 + ) + click_action = copy.deepcopy(action) + click_action.action_type = 'click' + with ( + mock.patch.object(adb_utils, 'tap_screen') as mock_tap_screen, + mock.patch.object(adb_utils, 'type_text') as mock_type_text, + mock.patch.object( + adb_utils, 'press_enter_button' + ) as mock_press_enter_button, + ): + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_tap_screen.assert_called_once_with(50, 50, self.mock_env) + mock_type_text.assert_called_once_with( + 'test input', self.mock_env, timeout_sec=10 + ) + mock_press_enter_button.assert_called_once_with(self.mock_env) + + def test_scroll(self): + action = json_action.JSONAction(action_type='scroll', direction='down') + with ( + mock.patch.object( + adb_utils, 'generate_swipe_command' + ) as mock_generate_swipe_command, + mock.patch.object( + adb_utils, 'issue_generic_request' + ) as mock_issue_generic_request, + ): + mock_generate_swipe_command.return_value = 'command' + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_generate_swipe_command.assert_called_once_with(50, 50, 50, 0) + mock_issue_generic_request.assert_called_once_with( + 'command', self.mock_env + ) + + def test_swipe(self): + action = json_action.JSONAction(action_type='swipe', direction='up') + with ( + mock.patch.object( + adb_utils, 'generate_swipe_command' + ) as mock_generate_swipe_command, + mock.patch.object( + adb_utils, 'issue_generic_request' + ) as mock_issue_generic_request, + ): + mock_generate_swipe_command.return_value = 'command' + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_generate_swipe_command.assert_called_once_with(50, 100, 50, 0, 500) + mock_issue_generic_request.assert_called_once_with( + 'command', self.mock_env + ) + + def test_open_app(self): + action = json_action.JSONAction(action_type='open_app', app_name='test app') + with mock.patch.object(adb_utils, 'launch_app') as mock_launch_app: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_launch_app.assert_called_once_with('test app', self.mock_env) + + def test_double_tap(self): + action = json_action.JSONAction(action_type='double_tap', index=0) + with mock.patch.object(adb_utils, 'double_tap') as mock_double_tap: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_double_tap.assert_called_once_with(25, 30, self.mock_env) + + def test_long_press(self): + action = json_action.JSONAction(action_type='long_press', index=0) + with mock.patch.object(adb_utils, 'long_press') as mock_long_press: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_long_press.assert_called_once_with(25, 30, self.mock_env) + + def test_keyboard_enter(self): + action = json_action.JSONAction(action_type='keyboard_enter') + with mock.patch.object( + adb_utils, 'press_enter_button' + ) as mock_press_enter_button: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_press_enter_button.assert_called_once_with(self.mock_env) + + def test_navigate_home(self): + action = json_action.JSONAction(action_type='navigate_home') + with mock.patch.object( + adb_utils, 'press_home_button' + ) as mock_press_home_button: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_press_home_button.assert_called_once_with(self.mock_env) + + def test_navigate_back(self): + action = json_action.JSONAction(action_type='navigate_back') + with mock.patch.object( + adb_utils, 'press_back_button' + ) as mock_press_back_button: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_press_back_button.assert_called_once_with(self.mock_env) + + def test_wait(self): + action = json_action.JSONAction(action_type='wait') + with mock.patch.object(time, 'sleep') as mock_sleep: + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + mock_sleep.assert_called_once_with(1.0) + + def test_unknown_action(self): + action = json_action.JSONAction(action_type=json_action.UNKNOWN) + actuation.execute_adb_action( + action, self.screen_elements, self.screen_size, self.mock_env + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/adb_utils.py b/evaluation/AndroidWorld/android_world/env/adb_utils.py new file mode 100644 index 0000000..799f214 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/adb_utils.py @@ -0,0 +1,1663 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilties to interact with the environment using adb.""" + +import os +import re +import time +from typing import Any, Callable, Collection, Iterable, Literal, Optional, TypeVar +import unicodedata +from absl import logging +from android_env import env_interface +from android_env.components import errors +from android_env.proto import adb_pb2 +import immutabledict + +T = TypeVar('T') + +_DEFAULT_TIMEOUT_SECS = 10 + +# Maps app names to the activity that should be launched to open the app. +_PATTERN_TO_ACTIVITY = immutabledict.immutabledict({ + 'google chrome|chrome': ( + 'com.android.chrome/com.google.android.apps.chrome.Main' + ), + 'google chat': 'com.google.android.apps.dynamite/com.google.android.apps.dynamite.startup.StartUpActivity', + 'settings|system settings': 'com.android.settings/.Settings', + 'youtube|yt': 'com.google.android.youtube/com.google.android.apps.youtube.app.WatchWhileActivity', + 'google play|play store|gps': ( + 'com.android.vending/com.google.android.finsky.activities.MainActivity' + ), + 'gmail|gemail|google mail|google email|google mail client': ( + 'com.google.android.gm/.ConversationListActivityGmail' + ), + 'google maps|gmaps|maps|google map': ( + 'com.google.android.apps.maps/com.google.android.maps.MapsActivity' + ), + 'google photos|gphotos|photos|google photo|google pics|google images': 'com.google.android.apps.photos/com.google.android.apps.photos.home.HomeActivity', + 'google calendar|gcal': ( + 'com.google.android.calendar/com.android.calendar.AllInOneActivity' + ), + 'camera': 'com.android.camera2/com.android.camera.CameraLauncher', + 'audio recorder': 'com.dimowner.audiorecorder/com.dimowner.audiorecorder.app.welcome.WelcomeActivity', + 'google drive|gdrive|drive': ( + 'com.google.android.apps.docs/.drive.startup.StartupActivity' + ), + 'google keep|gkeep|keep': ( + 'com.google.android.keep/.activities.BrowseActivity' + ), + 'grubhub': ( + 'com.grubhub.android/com.grubhub.dinerapp.android.splash.SplashActivity' + ), + 'tripadvisor': 'com.tripadvisor.tripadvisor/com.tripadvisor.android.ui.launcher.LauncherActivity', + 'starbucks': 'com.starbucks.mobilecard/.main.activity.LandingPageActivity', + 'google docs|gdocs|docs': 'com.google.android.apps.docs.editors.docs/com.google.android.apps.docs.editors.homescreen.HomescreenActivity', + 'google sheets|gsheets|sheets': 'com.google.android.apps.docs.editors.sheets/com.google.android.apps.docs.editors.homescreen.HomescreenActivity', + 'google slides|gslides|slides': 'com.google.android.apps.docs.editors.slides/com.google.android.apps.docs.editors.homescreen.HomescreenActivity', + 'clock': 'com.google.android.deskclock/com.android.deskclock.DeskClock', + 'google search|google': 'com.google.android.googlequicksearchbox/com.google.android.googlequicksearchbox.SearchActivity', + 'contacts': 'com.google.android.contacts/com.android.contacts.activities.PeopleActivity', + 'facebook|fb': 'com.facebook.katana/com.facebook.katana.LoginActivity', + 'whatsapp|wa': 'com.whatsapp/com.whatsapp.Main', + 'instagram|ig': ( + 'com.instagram.android/com.instagram.mainactivity.MainActivity' + ), + 'twitter|tweet': 'com.twitter.android/com.twitter.app.main.MainActivity', + 'snapchat|sc': 'com.snapchat.android/com.snap.mushroom.MainActivity', + 'telegram|tg': 'org.telegram.messenger/org.telegram.ui.LaunchActivity', + 'linkedin': ( + 'com.linkedin.android/com.linkedin.android.authenticator.LaunchActivity' + ), + 'spotify|spot': 'com.spotify.music/com.spotify.music.MainActivity', + 'netflix': 'com.netflix.mediaclient/com.netflix.mediaclient.ui.launch.UIWebViewActivity', + 'amazon shopping|amazon|amzn': ( + 'com.amazon.mShop.android.shopping/com.amazon.mShop.home.HomeActivity' + ), + 'tiktok|tt': 'com.zhiliaoapp.musically/com.ss.android.ugc.aweme.splash.SplashActivity', + 'discord': 'com.discord/com.discord.app.AppActivity$Main', + 'reddit': 'com.reddit.frontpage/com.reddit.frontpage.MainActivity', + 'pinterest': 'com.pinterest/com.pinterest.activity.PinterestActivity', + 'android world': 'com.example.androidworld/.MainActivity', + 'files': 'com.google.android.documentsui/com.android.documentsui.files.FilesActivity', + 'markor': 'net.gsantner.markor/net.gsantner.markor.activity.MainActivity', + 'clipper': 'ca.zgrs.clipper/ca.zgrs.clipper.Main', + 'messages': 'com.google.android.apps.messaging/com.google.android.apps.messaging.ui.ConversationListActivity', + 'simple sms messenger|simple sms': 'com.simplemobiletools.smsmessenger/com.simplemobiletools.smsmessenger.activities.MainActivity', + 'dialer|phone': 'com.google.android.dialer/com.google.android.dialer.extensions.GoogleDialtactsActivity', + 'simple calendar pro|simple calendar': 'com.simplemobiletools.calendar.pro/com.simplemobiletools.calendar.pro.activities.MainActivity', + 'simple gallery pro|simple gallery': 'com.simplemobiletools.gallery.pro/com.simplemobiletools.gallery.pro.activities.MainActivity', + 'miniwob': 'com.google.androidenv.miniwob/com.google.androidenv.miniwob.app.MainActivity', + 'simple draw pro': 'com.simplemobiletools.draw.pro/com.simplemobiletools.draw.pro.activities.MainActivity', + 'pro expense|pro expense app': ( + 'com.arduia.expense/com.arduia.expense.ui.MainActivity' + ), + 'broccoli|broccoli app|broccoli recipe app|recipe app': ( + 'com.flauschcode.broccoli/com.flauschcode.broccoli.MainActivity' + ), + 'caa|caa test|context aware access': 'com.google.ccc.hosted.contextawareaccess.thirdpartyapp/.ChooserActivity', + 'osmand': 'net.osmand/net.osmand.plus.activities.MapActivity', + 'tasks|tasks app|tasks.org:': ( + 'org.tasks/com.todoroo.astrid.activity.MainActivity' + ), + 'open tracks sports tracker|activity tracker|open tracks|opentracks': ( + 'de.dennisguse.opentracks/de.dennisguse.opentracks.TrackListActivity' + ), + 'joplin|joplin app': 'net.cozic.joplin/.MainActivity', + 'vlc|vlc app|vlc player': 'org.videolan.vlc/.gui.MainActivity', + 'retro music|retro|retro player': ( + 'code.name.monkey.retromusic/.activities.MainActivity' + ), +}) +_ORIENTATIONS = { + 'portrait': '0', + 'landscape': '1', + 'portrait_reversed': '2', + 'landscape_reversed': '3', +} + +# Special app names that will trigger opening the default app. +_DEFAULT_URIS: dict[str, str] = { + 'calendar': 'content://com.android.calendar', + 'browser': 'http://', + 'contacts': 'content://contacts/people/', + 'email': 'mailto:', + 'gallery': 'content://media/external/images/media/', +} + + +def check_ok(response: adb_pb2.AdbResponse, message=None) -> None: + """Check an ADB response and raise RuntimeError if not OK. + + Args: + response: AdbResponse to check. + message: Error message to raise on non-ok response. If not specified, a + generic "ADB command failed" error message is used. + + Raises: + RuntimeError: If response status is not OK. + """ + if response.status != adb_pb2.AdbResponse.Status.OK: + if message is not None: + raise RuntimeError(message) + else: + raise RuntimeError( + f'ADB command failed with status {response.status}:' + f' {response.generic.output.decode()}.' + ) + + +def start_activity( + activity: str, + extra_args: Optional[Collection[str]], + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to launch the given activity. + + Args: + activity: The activity to launch in standard android_package/activity_name + format. + extra_args: Optional set of arguments to be issued with the ABD broadcast. + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + logging.info('Attempting to launch %r', activity) + response = env.execute_adb_call( + adb_pb2.AdbRequest( + start_activity=adb_pb2.AdbRequest.StartActivity( + full_activity=activity, extra_args=extra_args + ), + timeout_sec=timeout_sec, + ) + ) + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to launch activity: %r', activity) + return response + + logging.info('Launch package output %r', response.generic.output) + return response + + +def get_current_activity( + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> tuple[Optional[str], adb_pb2.AdbResponse]: + """Returns the full activity name that is currently opened to the user. + + Args: + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + A tuple (current_activity_name, adb_response) containing the string with + the current activity or None if no current activity can be + extracted, and the adb response received after issuing the request. + """ + response = env.execute_adb_call( + adb_pb2.AdbRequest( + get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity(), + timeout_sec=timeout_sec, + ) + ) + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.warning( + 'Failed to obtain visible task. error_message: %r', + response.error_message, + ) + return (None, response) + + activity = response.get_current_activity.full_activity + return (activity, response) + + +def tap_screen( + x: int, + y: int, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to tap the screen at the specified point. + + Args: + x: X coordinate on the screen, in pixels. + y: Y coordinate on the screen, in pixels. + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + logging.info('Attemting to tap the screen at (%d, %d)', x, y) + response = env.execute_adb_call( + adb_pb2.AdbRequest( + tap=adb_pb2.AdbRequest.Tap(x=x, y=y), timeout_sec=timeout_sec + ) + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to tap the screen') + + return response + + +def double_tap( + x: int, + y: int, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues two AdbRequests to double tap the screen at the specified point. + + Args: + x: X coordinate on the screen, in pixels. + y: Y coordinate on the screen, in pixels. + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the second tap request. + """ + logging.info('Attempting to double tap the screen at (%d, %d)', x, y) + first_tap = tap_screen(x, y, env, timeout_sec=0) + second_tap = tap_screen(x, y, env, timeout_sec=timeout_sec) + logging.info('First tap: %s', first_tap) + logging.info('Second tap: %s', second_tap) + return second_tap + + +def long_press( + x: int, + y: int, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to long press the screen at the specified point. + + Args: + x: X coordinate on the screen, in pixels. + y: Y coordinate on the screen, in pixels. + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing long press. + """ + logging.info('Attempting to long press the screen at (%d, %d)', x, y) + return issue_generic_request( + ['shell', 'input', 'swipe', str(x), str(y), str(x), str(y), '1000'], + env, + timeout_sec, + ) + + +def press_home_button( + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to press the HOME button in the nav bar. + + Args: + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + logging.info('Attempting to press the HOME button') + response = env.execute_adb_call( + adb_pb2.AdbRequest( + press_button=adb_pb2.AdbRequest.PressButton( + button=adb_pb2.AdbRequest.PressButton.HOME + ), + timeout_sec=timeout_sec, + ) + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to press the HOME button') + return response + + +def press_back_button( + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to press the BACK button in the nav bar. + + Args: + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + logging.info('Attemting to press the BACK button') + response = env.execute_adb_call( + adb_pb2.AdbRequest( + press_button=adb_pb2.AdbRequest.PressButton( + button=adb_pb2.AdbRequest.PressButton.BACK + ), + timeout_sec=timeout_sec, + ) + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to press the BACK button') + + return response + + +def press_enter_button( + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to press the ENTER button in the nav bar. + + Args: + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + logging.info('Attemting to press the ENTER button') + response = env.execute_adb_call( + adb_pb2.AdbRequest( + press_button=adb_pb2.AdbRequest.PressButton( + button=adb_pb2.AdbRequest.PressButton.ENTER + ), + timeout_sec=timeout_sec, + ) + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to press the ENTER button') + + return response + + +def press_keyboard_generic( + keycode: str, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues an AdbRequest to press any button in the keyboard. + + Args: + keycode: The keycode to press. + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + logging.info('Attemting to press the keyboard button: %s', keycode) + + response = issue_generic_request( + ['shell', 'input', 'keyevent', keycode], + env, + timeout_sec, + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to press the keyboard button: %s', keycode) + + return response + + +def _adb_text_format(text: str) -> str: + """Prepares text for use with adb.""" + to_escape = [ + '\\', + ';', + '|', + '`', + '\r', + ' ', + "'", + '"', + '&', + '<', + '>', + '(', + ')', + '#', + '$', + ] + for char in to_escape: + text = text.replace(char, '\\' + char) + normalized_text = unicodedata.normalize('NFKD', text) + return normalized_text.encode('ascii', 'ignore').decode('ascii') + + +def _split_words_and_newlines(text: str) -> Iterable[str]: + """Split lines of text into individual words and newline chars.""" + lines = text.split('\n') + for i, line in enumerate(lines): + words = line.split(' ') + for j, word in enumerate(words): + yield word + if j < len(words) - 1: + yield '%s' + if i < len(lines) - 1: + yield '\n' + + +def type_text( + text: str, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> None: + """Issues an AdbRequest to type the specified text string word-by-word. + + It types word-by-word to fix issue where sometimes long text strings can be + typed out of order at the character level. Additionally, long strings can time + out and word-by-word fixes this, while allowing us to keep a lot timeout per + word. + + Args: + text: The text string to be typed. + env: The environment. + timeout_sec: A timeout to use for this operation. Note: For longer texts, + this should be longer as it takes longer to type. + """ + words = _split_words_and_newlines(text) + for word in words: + if word == '\n': + logging.info('Found \\n, pressing enter button.') + press_enter_button(env) + continue + formatted = _adb_text_format(word) + logging.info('Attempting to type word: %r', formatted) + response = env.execute_adb_call( + adb_pb2.AdbRequest( + input_text=adb_pb2.AdbRequest.InputText(text=formatted), + timeout_sec=timeout_sec, + ) + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to type word: %r', formatted) + + +def issue_generic_request( + args: Collection[str] | str, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Issues a generic adb command. + + Example: + ~~~~~~~ + + issue_generic_request(['shell', 'ls'], env) + # or + issue_generic_request('shell ls', env) + + Args: + args: Set of arguments to be issued with the ABD broadcast. Can also be a + string. + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + The adb response received after issuing the request. + """ + if isinstance(args, str): + args_str = args + args = args.split(' ') + else: + args_str = ' '.join(args) + logging.info('Issuing generic adb request: %r', args_str) + + response = env.execute_adb_call( + adb_pb2.AdbRequest( + generic=adb_pb2.AdbRequest.GenericRequest(args=args), + timeout_sec=timeout_sec, + ) + ) + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to issue generic adb request: %r', args_str) + + return response + + +def get_adb_activity(app_name: str) -> Optional[str]: + """Get a mapping of regex patterns to ADB activities top Android apps.""" + for pattern, activity in _PATTERN_TO_ACTIVITY.items(): + if re.match(pattern.lower(), app_name.lower()): + return activity + + +def get_all_package_names( + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> list[str]: + """Returns all packages installed on the device. + + Args: + env: The AndroidEnv interface. + timeout_sec: A timeout to use for this operation. + + Returns: + A list of installed package names. + """ + response = env.execute_adb_call( + adb_pb2.AdbRequest( + package_manager=adb_pb2.AdbRequest.PackageManagerRequest( + list=adb_pb2.AdbRequest.PackageManagerRequest.List( + packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages() + ) + ), + timeout_sec=timeout_sec, + ) + ) + if response.status != adb_pb2.AdbResponse.Status.OK: + logging.error('Failed to issue package manager request.') + + package_names = list(response.package_manager.list.items) + return package_names + + +def get_all_apps( + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> list[str]: + """Returns all apps installed on the device. + + Note: the output list will not be exhaustive as it is currently based on a + mapping we define, so any apps not included in that mapping will not be + output here. + + Args: + env: The AndroidEnv interface. + timeout_sec: A timeout to use for this operation. If not set the default + timeout will be used. + + Returns: + A list of app names. + """ + packages = get_all_package_names(env, timeout_sec) + package_to_app = { + v.split('/')[0]: k.split('|')[0] for k, v in _PATTERN_TO_ACTIVITY.items() + } + app_names = [] + for package in packages: + if package in package_to_app: + app_names.append(package_to_app[package]) + + return app_names + + +def _launch_default_app( + app_key: str, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Launches a default application with a predefined data URI.""" + if app_key not in _DEFAULT_URIS: + raise ValueError( + f'Unrecognized app key: {app_key}. Must be one of' + f' {list(_DEFAULT_URIS.keys())}' + ) + data_uri = _DEFAULT_URIS[app_key] + adb_command = [ + 'shell', + 'am', + 'start', + '-a', + 'android.intent.action.VIEW', + '-d', + data_uri, + ] + response = issue_generic_request(adb_command, env, timeout_sec) + return response + + +def launch_app( + app_name: str, + env: env_interface.AndroidEnvInterface, +) -> Optional[str]: + """Uses regex and ADB activity to try to launch an app. + + Args: + app_name: The name of the app, as represented as a key in + _PATTERN_TO_ACTIVITY. + env: The environment. + + Returns: + The name of the app that is launched. + """ + + if app_name in _DEFAULT_URIS: + _launch_default_app(app_name, env) + return app_name + + activity = get_adb_activity(app_name) + if activity is None: + # If the app name is not in the mapping, assume it is a package name. + response = issue_generic_request( + ['shell', 'monkey', '-p', app_name, '1'], env, timeout_sec=5 + ) + logging.info('Launching app by package name, response: %r', response) + return app_name + start_activity(activity, extra_args=[], env=env, timeout_sec=5) + return app_name + + +def extract_package_name(activity: str) -> str: + """Extract the package name from the activity string.""" + return activity.split('/')[0] + + +def close_recents(env: env_interface.AndroidEnvInterface): + """Closes all recent apps.""" + response = issue_generic_request('shell dumpsys activity recents', env) + if response.status != adb_pb2.AdbResponse.Status.OK: + return + recents_ids = re.findall(r'id=(\d+)', response.generic.output.decode()) + for recents_id in recents_ids: + issue_generic_request(['shell', 'am', 'stack', 'remove', recents_id], env) + + +def close_app( + app_name: str, + env: env_interface.AndroidEnvInterface, + timeout_sec: Optional[float] = _DEFAULT_TIMEOUT_SECS, +) -> Optional[str]: + """Uses regex and ADB package name to try to directly close an app. + + Args: + app_name: The name of the app, as represented as a key in + _PATTERN_TO_ACTIVITY. + env: The environment. + timeout_sec: The timeout. + + Returns: + The app name that is closed. + """ + activity = get_adb_activity(app_name) + if activity is None: + logging.error('Failed to close app: %r', app_name) + return None + package_name = extract_package_name(activity) + issue_generic_request( + ['shell', 'am', 'force-stop', package_name], env, timeout_sec + ) + return app_name + + +def generate_swipe_command( + start_x: int, + start_y: int, + end_x: int, + end_y: int, + duration_ms: Optional[int] = None, +) -> list[str]: + """Sends a swipe action to the simulator. + + Args: + start_x: The x-coordinate of the start of the swipe. + start_y: The y-coordinate of the start of the swipe. + end_x: The x-coordinate of the end of the swipe. + end_y: The y-coordinate of the end of the swipe. + duration_ms: If given, the duration of time in milliseconds to take to + complete the swipe. This value can differentiate a swipe from a fling. + + Returns: + List of adb arguments. + """ + duration_str = str(duration_ms) if duration_ms else '' + return [ + 'shell', + 'input', + 'swipe', + str(start_x), + str(start_y), + str(end_x), + str(end_y), + duration_str, + ] + + +def send_android_intent( + command: str, + action: str, + env: env_interface.AndroidEnvInterface, + data_uri: str | None = None, + mime_type: str | None = None, + extras: dict[str, Any] | None = None, + timeout_sec: int = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Sends an intent to Android device using adb. + + This is a low-level command for sending an intent with additional parameters. + When these additional parameters are not necessary, consider instead using + `adb_utils.start_activity()` or `env.execute_adb_call()` with + `AdbRequest.StartActivity` or `AdbRequest.SendBroadcast`. + + Args: + command: Either "start" for start activity intents or "broadcast" for + broadcast intents. + action: The broadcast action (e.g. "android.intent.action.VIEW"). + env: The environment to which the broadcast is sent. + data_uri: Optional intent data URI (e.g. "content://contacts/people/1"). + mime_type: Optional mime type (e.g. "image/png"). + extras: Dictionary containing keys and values to be sent as extras. + timeout_sec: The maximum time in seconds to wait for the broadcast to + complete. + + Returns: + AdbResponse object. + """ + if command not in ['start', 'broadcast']: + raise ValueError('Intent command must be either "start" or "broadcast"') + + adb_command = ['shell', 'am', command, '-a', action] + + if data_uri: + adb_command.extend(['-d', f'"{data_uri}"']) + + if mime_type: + adb_command.extend(['-t', f'"{mime_type}"']) + + if extras: + for key, value in extras.items(): + if isinstance(value, tuple): + type_override, value = value + if type_override == 'str': + adb_command.extend(['--es', key, f'"{value}"']) + elif type_override == 'bool': + adb_command.extend(['--ez', key, f'"{value}"']) + elif type_override == 'int': + adb_command.extend(['--ei', key, f'"{value}"']) + elif type_override == 'long': # long type only available via override. + adb_command.extend(['--el', key, f'"{value}"']) + elif type_override == 'float': + adb_command.extend(['--ef', key, f'"{value}"']) + elif type_override == 'string array': + array_str = ','.join(value) + adb_command.extend(['--esa', key, f'"{array_str}"']) + elif isinstance(value, str): + adb_command.extend(['--es', key, f'"{value}"']) + elif isinstance(value, bool): + adb_command.extend(['--ez', key, f'"{value}"']) + elif isinstance(value, int): + adb_command.extend(['--ei', key, f'"{value}"']) + # long type only available via override above. + elif isinstance(value, float): + adb_command.extend(['--ef', key, f'"{value}"']) + elif isinstance(value, list): + array_str = ','.join(value) + adb_command.extend(['--esa', key, f'"{array_str}"']) + else: + raise ValueError(f'Unrecognized extra type for {key}') + + return issue_generic_request(adb_command, env, timeout_sec) + + +def get_api_level(env: env_interface.AndroidEnvInterface) -> int: + """Gets the API level of the device. + + Args: + env: The environment. + + Returns: + The API level. + + Raises: + RuntimeError: If adb command does not successfully execute. + """ + version = issue_generic_request( + ['shell', 'getprop ro.build.version.sdk'], env + ) + if version.status != adb_pb2.AdbResponse.Status.OK: + raise RuntimeError('Failed to get API level.') + return int(version.generic.output) + + +def _toggle_svc( + service: str, + on_or_off: Literal['on', 'off'], + env: env_interface.AndroidEnvInterface, +) -> adb_pb2.AdbResponse: + """Toggles a system service on or off using svc. + + Args: + service: The name of the service to toggle. + on_or_off: The state to set ('on' or 'off'). + env: The Android environment. + + Returns: + adb status. + + Raises: + ValueError: If invalid on_or_off is provided. + """ + if on_or_off not in ('on', 'off'): + raise ValueError('Must be one of on or off.') + + cmd = 'enable' if on_or_off == 'on' else 'disable' + return issue_generic_request(['shell', 'svc', service, cmd], env) + + +def toggle_wifi( + env: env_interface.AndroidEnvInterface, on_or_off: Literal['on', 'off'] +) -> adb_pb2.AdbResponse: + """Toggles wifi on or off. + + Args: + env: The Android environment. + on_or_off: Whether to turn it on or off. + + Returns: + adb status. + """ + return _toggle_svc('wifi', on_or_off, env) + + +def toggle_bluetooth( + env: env_interface.AndroidEnvInterface, on_or_off: Literal['on', 'off'] +) -> adb_pb2.AdbResponse: + """Toggles Bluetooth on or off. + + Args: + env: The Android environment. + on_or_off: Whether to turn it on or off. + + Returns: + adb status. + """ + return _toggle_svc('bluetooth', on_or_off, env) + + +def set_brightness( + max_or_min: str, env: env_interface.AndroidEnvInterface +) -> adb_pb2.AdbResponse: + """Sets screen brightness to maximum or minimum. + + Args: + max_or_min: Whether to set it to maximum or minimum. + env: The environment. + + Returns: + The adb status. + + Raises: + ValueError: If invalid max_or_min is provided. + """ + if max_or_min not in ('max', 'min'): + raise ValueError('Must be one of max or min.') + + brightness_level = '255' if max_or_min == 'max' else '1' + + return issue_generic_request( + [ + 'shell', + 'settings', + 'put', + 'system', + 'screen_brightness', + brightness_level, + ], + env, + ) + + +def clear_app_data( + package_name: str, env: env_interface.AndroidEnvInterface +) -> adb_pb2.AdbResponse: + """Clears all data for a given package. + + Args: + package_name: The package name of the app whose data is to be cleared. + env: The environment. + + Returns: + adb status. + """ + try: + return issue_generic_request(['shell', 'pm', 'clear', package_name], env) + except errors.AdbControllerError as exc: + raise errors.AdbControllerError( + f'Failed to clear app data for package {package_name}. Is the app' + ' installed?' + ) from exc + + +def toggle_airplane_mode( + on_or_off: Literal['on', 'off'], env: env_interface.AndroidEnvInterface +) -> adb_pb2.AdbResponse: + """Toggles airplane mode on or off. + + Args: + on_or_off: Whether to turn it on or off. + env: The Android environment. + + Returns: + adb status. + + Raises: + ValueError: If invalid on_or_off is provided. + """ + if on_or_off not in ('on', 'off'): + raise ValueError('Must be one of on or off.') + state = '1' if on_or_off == 'on' else '0' + return issue_generic_request( + ['shell', 'settings', 'put', 'global', 'airplane_mode_on', state], env + ) + + +def install_apk( + apk_location: str, env: env_interface.AndroidEnvInterface +) -> None: + """Installs Android World APK. + + Args: + apk_location: Location of apk. + env: The environment. + + Raises: + ValueError: If apk location does not exist. + """ + if not os.path.exists(apk_location): + raise ValueError('APK does not exist.') + issue_generic_request(['install', apk_location], env, timeout_sec=30.0) + + +def check_airplane_mode(env: env_interface.AndroidEnvInterface) -> bool: + """Checks if airplane mode is enabled. + + Args: + env: The Android environment. + + Returns: + True if airplane mode is enabled, False otherwise. + + Raises: + RuntimeError: If cannot execute airplane mode check. + """ + response = issue_generic_request( + ['shell', 'settings', 'get', 'global', 'airplane_mode_on'], env + ) + + if response.status != adb_pb2.AdbResponse.Status.OK: + raise RuntimeError( + f'ADB command failed with status {response.status}:' + f' {response.generic.output.decode()}.' + ) + + return response.generic.output.decode().strip('\n') == '1' + + +def extract_broadcast_data(raw_output: str) -> Optional[str]: + """Extracts the data from an adb broadcast command output. + + Args: + raw_output: The adb command output. + + Returns: + Extracted data as a string, or None if the result is 0. + """ + if 'Broadcast completed: result=-1, data=' in raw_output: + return raw_output.split('data=')[1].strip('"\n') + elif 'Broadcast completed: result=0' in raw_output: + return None + else: + raise ValueError(f'Unexpected broadcast output: {raw_output}') + + +def _extract_clipper_output(raw_output: str) -> str: + """Parses the clipper output from the adb command. + + Args: + raw_output: The adb command output. + + Returns: + The clipboard content as a string. + + Raises: + RuntimeError: If the adb command does not successfully execute or if the + app is not in the foreground. + """ + parsed_data = extract_broadcast_data(raw_output) + if parsed_data is not None: + return parsed_data + else: + raise RuntimeError( + 'Clipper app must be in the foreground to access clipboard. ' + 'Additionally, app privileges must be granted manually by opening the ' + 'clipper app and granting them.' + ) + + +def get_clipboard_contents(env: env_interface.AndroidEnvInterface) -> str: + """Gets the clipboard content from the Android device. + + Args: + env: The environment. + + Returns: + The clipboard content as a string. + + Raises: + RuntimeError: If the adb command does not successfully execute or if the + app is not in the foreground. + """ + if launch_app('clipper', env) is None: + raise RuntimeError( + 'Clipper app must be in the foreground to access clipboard. You may' + ' need to install clipper app.' + ) + + time.sleep(0.5) + res = issue_generic_request( + ['shell', 'am', 'broadcast', '-a', 'clipper.get'], env + ) + + if res.status != adb_pb2.AdbResponse.Status.OK: + raise RuntimeError('Failed to get clipboard content.') + + output_str = res.generic.output.decode('utf-8') + result = _extract_clipper_output(output_str) + + press_back_button(env) + return result + + +def change_orientation( + orientation: str, env: env_interface.AndroidEnvInterface +) -> None: + """Changes the screen orientation. + + Args: + orientation: str, The new orientation. Can be portrait, landscape, + reverse_portrait, or reverse_landscape. + env: The environment. + + Raises: + ValueError if invalid orientation is provided. + """ + if orientation not in _ORIENTATIONS: + raise ValueError( + f'Unknown orientation provided: {orientation} not in' + f' {_ORIENTATIONS.keys()}' + ) + command = [ + 'shell', + 'settings', + 'put', + 'system', + ] + # Turn off accelerometer. + issue_generic_request(command + ['accelerometer_rotation', '0'], env) + issue_generic_request( + command + ['user_rotation', _ORIENTATIONS[orientation]], env + ) + + +def set_clipboard_contents( + content: str, env: env_interface.AndroidEnvInterface +) -> None: + """Sets the clipboard content on the Android device. + + NOTE: If using an Emulator, the contents of your clipboard on your local + machine may transfer to the emulator when focused on the emulator. Thus the + result of this function can be overwritten just by switching windows. + + Args: + content: Content to put into clipboard. + env: The environment. + + Raises: + RuntimeError: If the adb command does not successfully execute or if the + app is not in the foreground. + """ + if launch_app('clipper', env) is None: + raise RuntimeError( + 'Clipper app must be in the foreground to access clipboard. You may' + ' need to install clipper app.' + ) + + time.sleep(0.5) + content = _adb_text_format(content) + output_str = issue_generic_request( + ['shell', 'am', 'broadcast', '-a', 'clipper.set', '-e', 'text', content], + env, + ).generic.output.decode('utf-8') + _extract_clipper_output(output_str) + press_back_button(env) + + +def grant_permissions( + activity_name: str, + permission: str, + env: env_interface.AndroidEnvInterface, +) -> None: + """Grants permissions on an activity. + + This is useful because it prevents pop-ups prompting user/agent for + permission. + + See https://developer.android.com/reference/android/Manifest.permission for + available permissions to grant. + + Args: + activity_name: The name of the activity. + permission: The permission to grant. + env: The AndroidEnv instance. + """ + issue_generic_request( + ['shell', 'pm', 'grant', activity_name, permission], + env, + ) + + +def execute_sql_command( + db_path: str, + sql_command: str, + env: env_interface.AndroidEnvInterface, +) -> adb_pb2.AdbResponse: + """Execute an arbitrary SQL command on a SQLite database file via ADB. + + Args: + db_path: The path to the SQLite database on the Android device. + sql_command: The SQL command to execute. + env: The environment. + + Returns: + The adb response received after issuing the request. + """ + set_root_if_needed(env) + adb_command = ['shell', f'sqlite3 {db_path} "{sql_command}"'] + adb_response = issue_generic_request(adb_command, env) + return adb_response + + +def get_call_state( + env: env_interface.AndroidEnvInterface, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> str: + """Query the call state and the dialed number of the phone through ADB. + + Args: + env: The Android environment interface. + timeout_sec: A timeout for the ADB operation. + + Returns: + A tuple containing the call state as a string and the dialed number as a + string. + """ + adb_args = ['shell', 'dumpsys', 'telephony.registry'] + response = issue_generic_request(adb_args, env, timeout_sec) + + output = response.generic.output.decode('utf-8') + state_match = re.search(r'mCallState=(\d)', output) + + state = 'UNKNOWN' + + if state_match: + state_code = state_match.group(1) + if state_code == '0': + state = 'IDLE' + elif state_code == '1': + state = 'RINGING' + elif state_code == '2': + state = 'OFFHOOK' + + return state + + +def call_emulator( + env: env_interface.AndroidEnvInterface, + phone_number: str, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Simulate an incoming call in an emulator using ADB. + + Args: + env: The Android environment interface. + phone_number: The incoming phone number. + timeout_sec: A timeout for the ADB operation. + + Returns: + adb_pb2.AdbResponse: A response object containing the ADB operation result. + """ + escaped_phone_number = re.sub(r'[^0-9+]', '', phone_number) + adb_args = ['emu', 'gsm', 'call', f'{escaped_phone_number}'] + response = issue_generic_request(adb_args, env, timeout_sec) + return response + + +def end_call_if_active( + env: 'env_interface.AndroidEnvInterface', + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> None: + """Ends phone call if on an active call.""" + current_state = get_call_state(env, timeout_sec) + + # This check is crucial. Otherwise pressing endcall key results in black + # screen, potentially because it's simulating turning display off? + if current_state in ('OFFHOOK', 'RINGING'): + adb_args = ['shell', 'input', 'keyevent', 'KEYCODE_ENDCALL'] + issue_generic_request(adb_args, env, timeout_sec) + + +def clear_android_emulator_call_log( + env: env_interface.AndroidEnvInterface, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> None: + """Clears the call log of a specific Android emulator using the Android environment interface. + + Args: + env: The Android environment interface. + timeout_sec: A timeout for the ADB operation. + """ + adb_args = ['shell', 'content', 'delete', '--uri', 'content://call_log/calls'] + issue_generic_request(adb_args, env, timeout_sec) + + +def call_phone_number( + env: env_interface.AndroidEnvInterface, + phone_number: str, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Initiate a phone call using ADB. + + Args: + env: The Android environment interface. + phone_number: The phone number to dial. + timeout_sec: A timeout for the ADB operation. + + Returns: + A response object containing the ADB operation result. + """ + escaped_phone_number = re.sub(r'[^0-9]', '', phone_number) + adb_args = [ + 'shell', + 'am', + 'start', + '-a', + 'android.intent.action.CALL', + '-d', + f'tel:{escaped_phone_number}', + ] + return issue_generic_request(adb_args, env, timeout_sec) + + +def text_emulator( + env: env_interface.AndroidEnvInterface, + phone_number: str, + message: str, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Simulate an incoming text message in an emulator using ADB. + + Args: + env: The Android environment interface. + phone_number: The sender's phone number. + message: The text message content. + timeout_sec: A timeout for the ADB operation. + + Returns: + A response object containing the ADB operation result. + """ + escaped_phone_number = re.sub(r'[^0-9+]', '', phone_number) + adb_args = [ + 'emu', + 'sms', + 'send', + f'{escaped_phone_number}', + f'{message}', + ] + response = issue_generic_request(adb_args, env, timeout_sec) + return response + + +def set_default_app( + setting_key: str, + package_name: str, + env: env_interface.AndroidEnvInterface, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Set the default application for a given type using ADB. + + Args: + setting_key: The setting key for the default application type (e.g., + 'sms_default_application'). + package_name: The package name of the application to be set as default. + env: The Android environment interface. + timeout_sec: A timeout for the ADB operation. + + Returns: + A response object containing the ADB operation result. + """ + adb_args = ['shell', 'settings', 'put', 'secure', setting_key, package_name] + response = issue_generic_request(adb_args, env, timeout_sec) + return response + + +def disable_headsup_notifications( + env: env_interface.AndroidEnvInterface, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Disables the heads up notifications. + + Args: + env: The Android environment interface. + timeout_sec: A timeout for the ADB operation. + + Returns: + A response object containing the ADB operation result. + """ + adb_args = [ + 'shell', + 'settings', + 'put', + 'global', + 'heads_up_notifications_enabled', + '0', + ] + response = issue_generic_request(adb_args, env, timeout_sec) + return response + + +def enable_headsup_notifications( + env: env_interface.AndroidEnvInterface, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Enables the heads up notifications. + + Args: + env: The Android environment interface. + timeout_sec: A timeout for the ADB operation. + + Returns: + A response object containing the ADB operation result. + """ + adb_args = [ + 'shell', + 'settings', + 'put', + 'global', + 'heads_up_notifications_enabled', + '1', + ] + response = issue_generic_request(adb_args, env, timeout_sec) + return response + + +def put_settings( + namespace: adb_pb2.AdbRequest.SettingsRequest.Namespace, + key: str, + value: str, + env: env_interface.AndroidEnvInterface, +) -> adb_pb2.AdbResponse: + """Change a setting in the Android system via ADB. + + Args: + namespace: The namespace in which the setting resides (SYSTEM, SECURE, + GLOBAL). + key: The key of the setting to change. + value: The new value for the setting. + env: The Android environment interface. + + Returns: + The adb response received after issuing the request. + """ + if not key: + raise ValueError('Key must be provided.') + if not value: + raise ValueError('Value must be provided.') + settings_request = adb_pb2.AdbRequest.SettingsRequest( + name_space=namespace, + put=adb_pb2.AdbRequest.SettingsRequest.Put(key=key, value=value), + ) + adb_request = adb_pb2.AdbRequest(settings=settings_request) + return env.execute_adb_call(adb_request) + + +def delete_contacts( + env: env_interface.AndroidEnvInterface, + timeout_sec: float = _DEFAULT_TIMEOUT_SECS, +) -> adb_pb2.AdbResponse: + """Deletes all contacts. + + Args: + env: The Android environment interface. + timeout_sec: A timeout for the ADB operation. + + Returns: + A response object containing the ADB operation result. + """ + adb_args = [ + 'shell', + 'pm', + 'clear', + 'com.android.providers.contacts', + ] + response = issue_generic_request(adb_args, env, timeout_sec) + return response + + +def _parse_screen_size_response(response: str) -> tuple[int, int]: + """Parse the adb response to extract screen size. + + Args: + response: The adb response string. + + Returns: + The screen width and height in pixels. + """ + match = re.search(r'Physical size: (\d+)x(\d+)', response) + if match: + width, height = map(int, match.groups()) + return width, height + else: + raise ValueError( + f'Screen size information not found in adb response: "{response}"' + ) + + +def get_screen_size(env: env_interface.AndroidEnvInterface) -> tuple[int, int]: + """Get the screen size in pixels of an Android device via ADB. + + Args: + env: The environment. + + Returns: + The screen width and height in pixels. + """ + adb_command = ['shell', 'wm size'] + adb_response = issue_generic_request(adb_command, env) + return _parse_screen_size_response( + adb_response.generic.output.decode('utf-8') + ) + + +def get_logical_screen_size( + env: env_interface.AndroidEnvInterface, +) -> tuple[int, int]: + """Returns the logical screen size. + + The logical screen size is the screen size that applications use to render + their interfaces which might be different than the physical screen size when + orientation/resolution changes. The coordinates we get from A11y tree are + based on the logical screen size. + + Args: + env: The AndroidEnv interface. + + Returns: + The logical screen size in (width, height). + """ + response = issue_generic_request( + 'shell dumpsys input | grep logicalFrame', env + ) + if response.status: + raw_output = response.generic.output.decode('utf-8') + pattern = r'logicalFrame=\[0, 0, (\d+), (\d+)\]' + matches = re.findall(pattern, raw_output) + for m in matches: + if int(m[0]) == 0 and int(m[1]) == 0: + continue + width, height = (int(m[0]), int(m[1])) + return (width, height) + raise ValueError('Failed to get logical screen size.') + + +def get_physical_frame_boundary( + env: env_interface.AndroidEnvInterface, +) -> tuple[int, int, int, int]: + """Returns the physical frame boundary. + + Args: + env: The AndroidEnv interface. + + Returns: + First two integers are the coordinates for top left corner, last two are for + lower right corner. All coordinates are given in portrait orientation. + """ + response = issue_generic_request( + 'shell dumpsys input | grep physicalFrame', env + ) + if response.status: + raw_output = response.generic.output.decode('utf-8') + pattern = r'physicalFrame=\[(\d+), (\d+), (\d+), (\d+)\]' + matches = re.findall(pattern, raw_output) + for m in matches: + if ( + int(m[0]) == 0 + and int(m[1]) == 0 + and int(m[2]) == 0 + and int(m[3]) == 0 + ): + continue + orientation = get_orientation(env) + if orientation == 0 or orientation == 2: + return (int(m[0]), int(m[1]), int(m[2]), int(m[3])) + return (int(m[1]), int(m[0]), int(m[3]), int(m[2])) + raise ValueError('Failed to get physical frame boundary.') + + +def get_orientation( + env: env_interface.AndroidEnvInterface, +) -> int: + """Returns the current screen orientation. + + The returned value follows the normal convention, 0 for portrait, 1 for + landscape, 2 for reverse portrait, 3 for reverse landscape. + + Args: + env: The AndroidEnv interface. + + Returns: + The screen orientation. + """ + response = issue_generic_request( + 'shell dumpsys window | grep mCurrentRotation', env + ) + if response.status: + raw_output = response.generic.output.decode('utf-8') + pattern = r'mCurrentRotation=ROTATION_(\d+)' + matches = re.findall(pattern, raw_output) + for m in matches: + return int(m) // 90 + raise ValueError('Failed to get orientation.') + + +def set_screen_size( + width: int, + height: int, + env: env_interface.AndroidEnvInterface, +) -> adb_pb2.AdbResponse: + """Sets the (logical) screen size (resolution) of the Android device via ADB. + + Args: + width: The desired screen width. + height: The desired screen height. + env: The AndroidEnv interface. + + Returns: + The adb response received after issuing the request. + """ + # Command will fail if width equals height. + if width <= 0 or height <= 0 or width == height: + raise ValueError( + 'Screen size not valid (need to be positive, width can not equal' + ' height).' + ) + # Construct the ADB command for setting screen size + adb_command = ['shell', f'wm size {width}x{height}'] + + # Issue the command and return the response + return issue_generic_request(adb_command, env) + + +def retry(n: int) -> Callable[[Any], Any]: + """Decorator to retry ADB commands.""" + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + def wrapper(*args: Any, **kwargs: Any) -> T: + attempts = 0 + while attempts < n: + try: + return func(*args, **kwargs) + except errors.AdbControllerError: + attempts += 1 + if attempts >= n: + raise + print(f'Could not execute {func}. Retrying...') + time.sleep(2) + except Exception as exc: + raise exc + + return wrapper + + return decorator + + +def set_root_if_needed( + env: env_interface.AndroidEnvInterface, timeout_sec: Optional[float] = None +) -> adb_pb2.AdbResponse: + """Checks if ADB is running as root, and if not, attempts to set root. + + Args: + env: The environment. + timeout_sec: A timeout to use for this operation. + + Returns: + bool: True if root is set (or was already set), False otherwise. + """ + response = issue_generic_request(['shell', 'whoami'], env, timeout_sec) + + if response.generic.output.decode('utf-8').strip() == 'root': + return response + + return issue_generic_request(['root'], env, timeout_sec) + + +def uiautomator_dump(env) -> str: + """Issues a uiautomator dump request and returns the UI hierarchy.""" + dump_args = 'shell uiautomator dump /sdcard/window_dump.xml' + issue_generic_request(dump_args, env) + + read_args = 'shell cat /sdcard/window_dump.xml' + response = issue_generic_request(read_args, env) + + return response.generic.output.decode('utf-8') diff --git a/evaluation/AndroidWorld/android_world/env/adb_utils_test.py b/evaluation/AndroidWorld/android_world/env/adb_utils_test.py new file mode 100644 index 0000000..53d407b --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/adb_utils_test.py @@ -0,0 +1,267 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for adb_utils.""" + +from unittest import mock + +from absl.testing import absltest +from android_env import env_interface +from android_env.proto import adb_pb2 +from android_world.env import adb_utils + + +class AdbTestSetup(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_env = mock.patch.object( + env_interface, 'AndroidEnvInterface', autospec=True + ).start() + self.mock_issue_generic_request = mock.patch.object( + adb_utils, 'issue_generic_request', autospec=True + ).start() + + def tearDown(self): + super().tearDown() + mock.patch.stopall() + + +class PhoneUtilsTest(AdbTestSetup): + + def test_get_call_state_idle(self): + dumpsys_content = """last known state: +Phone Id=0 + mCallState=0 + mRingingCallState=0 + mForegroundCallState=0""" + # Setup + mock_dumpsys_response = adb_pb2.AdbResponse() + mock_dumpsys_response.generic.output = dumpsys_content.encode('utf-8') + self.mock_issue_generic_request.return_value = mock_dumpsys_response + + # Act + result = adb_utils.get_call_state(self.mock_env) + + # Assert + self.assertEqual(result, 'IDLE') + + def test_get_call_state_ringing(self): + dumpsys_content = """last known state: +Phone Id=0 + mCallState=2 + mRingingCallState=0 + mForegroundCallState=0""" + # Setup + mock_dumpsys_response = adb_pb2.AdbResponse() + mock_dumpsys_response.generic.output = dumpsys_content.encode('utf-8') + self.mock_issue_generic_request.return_value = mock_dumpsys_response + + # Act + result = adb_utils.get_call_state(self.mock_env) + + # Assert + self.assertEqual(result, 'OFFHOOK') + + def test_call_emulator(self): + mock_response = adb_pb2.AdbResponse() + mock_response.generic.output = b'Success' + self.mock_issue_generic_request.return_value = mock_response + + phone_number = '+123456789' + result = adb_utils.call_emulator(self.mock_env, phone_number) + + self.assertEqual(result.generic.output.decode(), 'Success') + + @mock.patch.object(adb_utils, 'get_call_state', autospec=True) + def test_end_call_if_active(self, mock_get_call_state): + mock_get_call_state.return_value = 'OFFHOOK' + adb_utils.end_call_if_active(self.mock_env) + + self.mock_issue_generic_request.assert_called() + + def test_clear_android_emulator_call_log(self): + adb_utils.clear_android_emulator_call_log(self.mock_env) + self.mock_issue_generic_request.assert_called() + + def test_call_phone_number(self): + mock_response = adb_pb2.AdbResponse() + mock_response.generic.output = b'Success' + self.mock_issue_generic_request.return_value = mock_response + + phone_number = '123456789' + result = adb_utils.call_phone_number(self.mock_env, phone_number) + + self.assertEqual(result.generic.output.decode(), 'Success') + + def test_text_emulator(self): + mock_response = adb_pb2.AdbResponse() + mock_response.generic.output = b'Success' + self.mock_issue_generic_request.return_value = mock_response + + phone_number = '+123456789' + message = 'Hello, world!' + result = adb_utils.text_emulator(self.mock_env, phone_number, message) + + self.assertEqual(result.generic.output.decode(), 'Success') + + +class AdbSettingsTest(AdbTestSetup): + + def test_set_default_app(self): + mock_response = adb_pb2.AdbResponse() + mock_response.generic.output = b'Success' + self.mock_issue_generic_request.return_value = mock_response + + setting_key = 'sms_default_application' + package_name = 'com.example.app' + result = adb_utils.set_default_app(self.mock_env, setting_key, package_name) + + self.assertEqual(result.generic.output.decode(), 'Success') + + def test_successful_put_operation(self): + self.mock_env.execute_adb_call.return_value = adb_pb2.AdbResponse() + + # Execute the function + response = adb_utils.put_settings( + namespace=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, + key='example_key', + value='example_value', + env=self.mock_env, + ) + + # Assertions + self.assertIsInstance(response, adb_pb2.AdbResponse) + self.mock_env.execute_adb_call.assert_called_once() + + def test_invalid_inputs_put_operation(self): + self.mock_env.execute_adb_call.return_value = adb_pb2.AdbResponse() + + # Invalid namespace (non-enum value) + with self.assertRaises(ValueError): + adb_utils.put_settings( + namespace='INVALID', # This should be an enum, not a string + key='example_key', + value='example_value', + env=self.mock_env, + ) + + # Empty key + with self.assertRaises(ValueError): + adb_utils.put_settings( + namespace=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, + key='', + value='example_value', + env=self.mock_env, + ) + + # Empty value + with self.assertRaises(ValueError): + adb_utils.put_settings( + namespace=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, + key='example_key', + value='', + env=self.mock_env, + ) + + +class AdbTypingTest(AdbTestSetup): + + def test_can_type_text(self): + with mock.patch.object( + env_interface.AndroidEnvInterface, 'execute_adb_call' + ) as mock_execute_adb_call: + mock_execute_adb_call.return_value = adb_pb2.AdbResponse( + status=adb_pb2.AdbResponse.Status.OK + ) + adb_utils.type_text('Type some\ntext', self.mock_env) + expected_calls = [ + mock.call( + adb_pb2.AdbRequest( + input_text=adb_pb2.AdbRequest.InputText(text='Type'), + timeout_sec=adb_utils._DEFAULT_TIMEOUT_SECS, + ) + ), + mock.call( + adb_pb2.AdbRequest( + input_text=adb_pb2.AdbRequest.InputText(text='%s'), + timeout_sec=adb_utils._DEFAULT_TIMEOUT_SECS, + ) + ), + mock.call( + adb_pb2.AdbRequest( + input_text=adb_pb2.AdbRequest.InputText(text='some'), + timeout_sec=adb_utils._DEFAULT_TIMEOUT_SECS, + ) + ), + mock.call( + adb_pb2.AdbRequest( + press_button=adb_pb2.AdbRequest.PressButton( + button=adb_pb2.AdbRequest.PressButton.ENTER + ), + timeout_sec=adb_utils._DEFAULT_TIMEOUT_SECS, + ) + ), + mock.call( + adb_pb2.AdbRequest( + input_text=adb_pb2.AdbRequest.InputText(text='text'), + timeout_sec=adb_utils._DEFAULT_TIMEOUT_SECS, + ) + ), + ] + mock_execute_adb_call.assert_has_calls(expected_calls) + + +class TestExtractBroadcastData(absltest.TestCase): + + def test_successful_data_extraction(self): + raw_output = 'Broadcast completed: result=-1, data="Test data"\n' + expected_result = 'Test data' + result = adb_utils.extract_broadcast_data(raw_output) + self.assertEqual(result, expected_result) + + def test_result_zero_returns_none(self): + raw_output = 'Broadcast completed: result=0\n' + result = adb_utils.extract_broadcast_data(raw_output) + self.assertIsNone(result) + + def test_unexpected_output_raises_error(self): + raw_output = 'Unexpected output format' + with self.assertRaises(ValueError): + adb_utils.extract_broadcast_data(raw_output) + + +class TestScreenUtils(absltest.TestCase): + + def test_parse_screen_size_response_success(self): + """Test successful parsing of screen size from adb response.""" + response = 'Physical size: 1080x2400' + expected_size = (1080, 2400) + self.assertEqual( + adb_utils._parse_screen_size_response(response), expected_size + ) + + def test_parse_screen_size_response_failure(self): + """Test parsing failure when adb response is in an unexpected format.""" + response = 'Invalid response format' + with self.assertRaises(ValueError) as context: + adb_utils._parse_screen_size_response(response) + self.assertIn( + 'Screen size information not found in adb response', + str(context.exception), + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/android_world_controller.py b/evaluation/AndroidWorld/android_world/env/android_world_controller.py new file mode 100644 index 0000000..62e3818 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/android_world_controller.py @@ -0,0 +1,322 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Controller for Android that adds UI tree information to the observation.""" + +import contextlib +import enum +import os +import time +from typing import Any +from typing import cast +from typing import Optional +from absl import logging +from android_env import env_interface +from android_env import loader +from android_env.components import config_classes +from android_env.proto.a11y import android_accessibility_forest_pb2 +from android_env.wrappers import a11y_grpc_wrapper +from android_env.wrappers import base_wrapper +from android_world.env import adb_utils +from android_world.env import representation_utils +from android_world.utils import file_utils +import dm_env + + +def _has_wrapper( + env: env_interface.AndroidEnvInterface, + target_wrapper: Any, +) -> bool: + """Checks recursively if an environment object has a certain wrapper. + + Args: + env: The environment object potentially wrapped. + target_wrapper: The wrapper type to search for. + + Returns: + True if the target_wrapper is found, otherwise False. + """ + if isinstance(env, target_wrapper): + return True + elif hasattr(env, '_env'): + return _has_wrapper(env._env, target_wrapper) # pylint: disable=protected-access + else: + return False + + +def get_a11y_tree( + env: env_interface.AndroidEnvInterface, + max_retries: int = 5, + sleep_duration: float = 1.0, +) -> android_accessibility_forest_pb2.AndroidAccessibilityForest: + """Gets a11y tree. + + Args: + env: AndroidEnv. + max_retries: Maximum number of retries to get a11y tree. + sleep_duration: Time to sleep between each retry in seconds. + + Returns: + A11y tree. + + Raises: + RuntimeError: If the a11y tree was not able to be retrieved. + """ + if not _has_wrapper(env, a11y_grpc_wrapper.A11yGrpcWrapper): + raise ValueError( + 'Must use a11y_grpc_wrapper.A11yGrpcWrapper to get the a11y tree.' + ) + env = cast(a11y_grpc_wrapper.A11yGrpcWrapper, env) + if adb_utils.retry(3)(adb_utils.check_airplane_mode)(env): + logging.warning( + 'Airplane mode is on -- cannot retrieve a11y tree via gRPC. Turning' + ' it off...' + ) + logging.info('Enabling networking...') + env.attempt_enable_networking() + time.sleep(1.0) + + forest: Optional[ + android_accessibility_forest_pb2.AndroidAccessibilityForest + ] = None + for _ in range(max_retries): + try: + forest = env.accumulate_new_extras()['accessibility_tree'][-1] # pytype:disable=attribute-error + return forest + except KeyError: + logging.warning('Could not get a11y tree, retrying.') + time.sleep(sleep_duration) + + if forest is None: + raise RuntimeError('Could not get a11y tree.') + return forest + + +_TASK_PATH = '/tmp/default.textproto' +DEFAULT_ADB_PATH = '~/Android/Sdk/platform-tools/adb' + + +# UI tree-specific keys that are added to observations: + +# The forest is essentially a comprehensive snapshot of all user interface +# elements currently displayed on an Android device's screen. Each 'tree' in +# this 'forest' represents the accessibility details of a different window or +# screen section, providing structured information. The tree's origin is from +# the AccessibilityService. Please see the following for more detail: +# https://developer.android.com/reference/android/accessibilityservice/AccessibilityService + +OBSERVATION_KEY_FOREST = 'forest' +# UI elements are specific nodes extracted from forest. See +# representation_utils.forest_to_ui_elements for details. +OBSERVATION_KEY_UI_ELEMENTS = 'ui_elements' + + +class A11yMethod(enum.Enum): + """Method to get a11y tree.""" + + # Custom gRPC wrapper that uses a11y forwarder app. + A11Y_FORWARDER_APP = 'a11y_forwarder_app' + + # From `uiautomator dump``. + UIAUTOMATOR = 'uiautomator' + + +def apply_a11y_forwarder_app_wrapper( + env: env_interface.AndroidEnvInterface, install_a11y_forwarding_app: bool +) -> env_interface.AndroidEnvInterface: + return a11y_grpc_wrapper.A11yGrpcWrapper( + env, + install_a11y_forwarding=install_a11y_forwarding_app, + start_a11y_service=True, + enable_a11y_tree_info=True, + latest_a11y_info_only=True, + ) + + +class AndroidWorldController(base_wrapper.BaseWrapper): + """Controller for an Android instance that adds accessibility tree data. + + The Accessibility Tree in Android is a tree-based structure, originally for + for assisting accessibility services. It provides information about UI + elements (like text, buttons, and images) in a hierarchical format. The tree + includes details such as the properties and actions available for each + element. + """ + + def __init__( + self, + env: env_interface.AndroidEnvInterface, + a11y_method: A11yMethod = A11yMethod.A11Y_FORWARDER_APP, + install_a11y_forwarding_app: bool = True, + ): + self._original_env = env + if a11y_method == A11yMethod.A11Y_FORWARDER_APP: + self._env = apply_a11y_forwarder_app_wrapper( + env, install_a11y_forwarding_app + ) + self._env.reset() # Initializes required server services in a11y wrapper. + else: + self._env = env + self._a11y_method = a11y_method + + @property + def device_screen_size(self) -> tuple[int, int]: + """Returns the physical screen size of the device: (width, height).""" + return adb_utils.get_screen_size(self._env) + + @property + def logical_screen_size(self) -> tuple[int, int]: + """Returns the logical screen size of the device. + + This will be different with the physical size if orientation or resolution + is changed. + """ + return adb_utils.get_logical_screen_size(self._env) + + @property + def env(self) -> env_interface.AndroidEnvInterface: + return self._env + + def refresh_env(self): + # pylint: disable=protected-access + # pytype: disable=attribute-error + # Reconnect to emulator and reload a11y wrapper in case we lose connection. + self._env = get_controller( + console_port=self.env._coordinator._simulator._config.emulator_launcher.emulator_console_port, + adb_path=self.env._coordinator._simulator._config.adb_controller.adb_path, + grpc_port=self.env._coordinator._simulator._config.emulator_launcher.grpc_port, + ).env + # pylint: enable=protected-access + # pytype: enable=attribute-error + + def _get_a11y_forest( + self, + ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest: + return get_a11y_tree(self._env) + + def get_a11y_forest( + self, + ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest: + """Returns the most recent a11y forest from the device.""" + try: + return self._get_a11y_forest() + except RuntimeError: + print( + 'Could not get a11y tree. Reconnecting to Android, reinitializing' + ' AndroidEnv, and restarting a11y forwarding.' + ) + self.refresh_env() + return self._get_a11y_forest() + + def get_ui_elements(self) -> list[representation_utils.UIElement]: + """Returns the most recent UI elements from the device.""" + if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP: + return representation_utils.forest_to_ui_elements( + self.get_a11y_forest(), + exclude_invisible_elements=True, + ) + else: + return representation_utils.xml_dump_to_ui_elements( + adb_utils.uiautomator_dump(self._env) + ) + + def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Adds a11y tree info to the observation.""" + if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP: + forest = self.get_a11y_forest() + ui_elements = representation_utils.forest_to_ui_elements( + forest, + exclude_invisible_elements=True, + ) + else: + forest = None + ui_elements = self.get_ui_elements() + timestep.observation[OBSERVATION_KEY_FOREST] = forest + timestep.observation[OBSERVATION_KEY_UI_ELEMENTS] = ui_elements + return timestep + + def pull_file( + self, remote_db_file_path: str, timeout_sec: Optional[float] = None + ) -> contextlib._GeneratorContextManager[str]: + """Pulls a file from the device to a temporary directory. + + The directory will be deleted when the context manager exits. + Args: + remote_db_file_path: The path to the file on the device. + timeout_sec: Timeout in seconds for the adb calls. + + Returns: + The path to the temporary directory containing the file. + """ + remote_db_directory = os.path.dirname(remote_db_file_path) + return file_utils.tmp_directory_from_device( + remote_db_directory, self.env, timeout_sec + ) + + def push_file( + self, + local_db_file_path: str, + remote_db_file_path: str, + timeout_sec: Optional[float] = None, + ) -> None: + """Pushes a local file to the device.""" + + remote_db_directory = os.path.dirname(remote_db_file_path) + + # First delete old .db, .db-wal, and .db-shm files. + file_utils.clear_directory(remote_db_directory, self) + file_utils.copy_data_to_device( + local_db_file_path, + remote_db_file_path, + self.env, + timeout_sec, + ) + + +def _write_default_task_proto() -> str: + with open(_TASK_PATH, 'w') as f: + f.write("""\ +id: "default" + +name: "Default task for device control." +description: "Empty task" + +max_episode_sec: 7200 # Prevent infinite episodes. + """) + return _TASK_PATH + + +def get_controller( + console_port: int = 5554, + adb_path: str = DEFAULT_ADB_PATH, + grpc_port: int = 8554, +) -> AndroidWorldController: + """Creates a controller by connecting to an existing Android environment.""" + + config = config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig( + path=_write_default_task_proto() + ), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + emulator_console_port=console_port, + adb_port=console_port + 1, + grpc_port=grpc_port, + ), + adb_controller=config_classes.AdbControllerConfig(adb_path=adb_path), + ), + ) + android_env_instance = loader.load(config) + logging.info('Setting up AndroidWorldController.') + return AndroidWorldController(android_env_instance) diff --git a/evaluation/AndroidWorld/android_world/env/android_world_controller_test.py b/evaluation/AndroidWorld/android_world/env/android_world_controller_test.py new file mode 100644 index 0000000..c01f3f1 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/android_world_controller_test.py @@ -0,0 +1,193 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from unittest import mock + +from absl.testing import absltest +from android_env import env_interface +from android_env.wrappers import a11y_grpc_wrapper +from android_world.env import adb_utils +from android_world.env import android_world_controller +from android_world.env import representation_utils +from android_world.utils import fake_adb_responses +from android_world.utils import file_test_utils +from android_world.utils import file_utils +import dm_env + + +def create_file_with_contents(contents: str) -> str: + temp_dir = tempfile.mkdtemp() + file_path = os.path.join(temp_dir, 'file.txt') + with open(file_path, 'w') as f: + f.write(contents) + return file_path + + +class AndroidWorldControllerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + mock_issue_generic_request = self.enter_context( + mock.patch.object(adb_utils, 'issue_generic_request') + ) + mock_issue_generic_request.return_value = ( + fake_adb_responses.create_successful_generic_response( + 'Physical size: 100x200' + ) + ) + self.mock_a11y_wrapper = self.enter_context( + mock.patch.object( + a11y_grpc_wrapper, + 'A11yGrpcWrapper', + spec=a11y_grpc_wrapper.A11yGrpcWrapper, + ) + ) + + self.table_name = 'events' + + self.mock_copy_db = self.enter_context( + mock.patch.object( + file_utils, + 'tmp_directory_from_device', + side_effect=file_test_utils.mock_tmp_directory_from_device, + ) + ) + self.mock_copy_data_to_device = self.enter_context( + mock.patch.object( + file_utils, + 'copy_data_to_device', + side_effect=file_test_utils.mock_copy_data_to_device, + ) + ) + + self.mock_remove_files = self.enter_context( + mock.patch.object( + file_utils, + 'clear_directory', + side_effect=file_test_utils.mock_remove_files, + ) + ) + + def test_initialization(self): + mock_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + + env = android_world_controller.AndroidWorldController(mock_env) + + self.mock_a11y_wrapper.assert_called_with( + mock_env, + install_a11y_forwarding=True, + start_a11y_service=True, + enable_a11y_tree_info=True, + latest_a11y_info_only=True, + ) + env._env.reset.assert_called_once() + + def test_screen_size(self): + mock_base_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + + env = android_world_controller.AndroidWorldController(mock_base_env) + + self.assertEqual(env.device_screen_size, (100, 200)) + + @mock.patch.object(adb_utils, 'get_logical_screen_size') + @mock.patch.object(android_world_controller, 'get_a11y_tree') + @mock.patch.object(representation_utils, 'forest_to_ui_elements') + def test_process_timestep( + self, mock_forest_to_ui, mock_get_a11y_tree, mock_get_logical_screen_size + ): + mock_base_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + env = android_world_controller.AndroidWorldController(mock_base_env) + mock_forest = mock.Mock() + mock_ui_elements = mock.Mock() + mock_get_logical_screen_size.return_value = (100, 200) + mock_get_a11y_tree.return_value = mock_forest + mock_forest_to_ui.return_value = mock_ui_elements + timestep = dm_env.TimeStep( + observation={}, reward=None, discount=None, step_type=None + ) + + processed_timestep = env._process_timestep(timestep) + + self.assertEqual(processed_timestep.observation['forest'], mock_forest) + self.assertEqual( + processed_timestep.observation['ui_elements'], mock_ui_elements + ) + mock_forest_to_ui.assert_called_with( + mock_forest, + exclude_invisible_elements=True, + ) + + @mock.patch.object(adb_utils, 'check_airplane_mode') + @mock.patch.object(android_world_controller, 'get_controller') + @mock.patch.object(android_world_controller, '_has_wrapper') + @mock.patch.object( + android_world_controller.AndroidWorldController, 'refresh_env' + ) + def test_refresh_env( + self, + mock_refresh_env, + mock_has_wrapper, + mock_get_controller, + mock_check_airplane_mode, + ): + del mock_has_wrapper, mock_get_controller, mock_check_airplane_mode + mock_base_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + env = android_world_controller.AndroidWorldController(mock_base_env) + unused_mock_check_airplane_mode = False + env._env.accumulate_new_extras.side_effect = [ + {}, + {}, + {}, + {}, + {}, + {'accessibility_tree': ['success']}, + ] + + forest = env.get_a11y_forest() + + self.assertEqual(forest, 'success') + mock_refresh_env.assert_called_once() + + def test_pull_file(self): + file_contents = 'test file contents' + remote_file_path = create_file_with_contents(file_contents) + mock_base_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + env = android_world_controller.AndroidWorldController(mock_base_env) + + with env.pull_file(remote_file_path) as local_dir: + local_path = os.path.split(remote_file_path)[1] + local_file = open(os.path.join(local_dir, local_path), 'r') + self.assertEqual(open(remote_file_path, 'r').read(), local_file.read()) + + self.mock_copy_db.assert_called_once_with( + os.path.dirname(remote_file_path), env._env, None + ) + + def test_push_file(self): + old_file_contents = 'test file contents' + new_file_contents = 'new file' + remote_file_path = create_file_with_contents(old_file_contents) + mock_base_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + env = android_world_controller.AndroidWorldController(mock_base_env) + new_file = create_file_with_contents(new_file_contents) + + env.push_file(new_file, remote_file_path, None) + + self.assertEqual(open(remote_file_path, 'r').read(), new_file_contents) + + +if __name__ == '__main__': + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/device_constants.py b/evaluation/AndroidWorld/android_world/env/device_constants.py new file mode 100644 index 0000000..0f9417c --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/device_constants.py @@ -0,0 +1,44 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants for the Pixel 6, API 33, emulator.""" + +import datetime + +# Screen dimensions of Pixel 6. +SCREEN_HEIGHT, SCREEN_WIDTH = 2400, 1080 + +# Where data on emulator is stored. +EMULATOR_DATA = "/storage/emulated/0/" + +# Location where app snapshots are stored. +SNAPSHOT_DATA = "/data/data/android_world/snapshots" + +# keep-sorted start +AUDIORECORDER_DATA = "/storage/emulated/0/Android/data/com.dimowner.audiorecorder/files/Music/records" +DOWNLOAD_DATA = "/storage/emulated/0/Download" +GALLERY_DATA = "/sdcard/DCIM" +MARKOR_DATA = "/storage/emulated/0/Documents/Markor" +MUSIC_DATA = "/sdcard/Music" +OSMAND_DATA = "/storage/emulated/0/Android/data/net.osmand/files" +PHOTOS_DATA = "/sdcard/Pictures" +VIDEOS_DATA = "/sdcard/Movies" +# keep-sorted end + +# Every task starts October 15, 2023 @ 15:34:00. +TIMEZONE = "UTC" +DT = datetime.datetime(2023, 10, 15, 15, 34, 0, tzinfo=datetime.timezone.utc) + +# Format the datetime object into the Android date-time format +ANDROID_DT = DT.strftime("%m%d%H%M%y.%S") diff --git a/evaluation/AndroidWorld/android_world/env/env_launcher.py b/evaluation/AndroidWorld/android_world/env/env_launcher.py new file mode 100644 index 0000000..e87da2e --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/env_launcher.py @@ -0,0 +1,116 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launches the environment used in the benchmark.""" + +import resource + +from absl import logging +from android_world.env import android_world_controller +from android_world.env import interface +from android_world.env.setup_device import setup +from android_world.utils import datetime_utils + + +# AndroidWorld is tested and developed on Pixel 6 with API 33. Other +# configurations may be supported, but are not yet tested. +_ANDROID_WORLD_API_LEVEL = 33 + + +def _get_env( + console_port: int, adb_path: str, grpc_port: int +) -> interface.AsyncEnv: + """Creates an AsyncEnv by connecting to an existing Android environment.""" + controller = android_world_controller.get_controller( + console_port, adb_path, grpc_port + ) + return interface.AsyncAndroidEnv(controller) + + +def _increase_file_descriptor_limit(limit: int = 32768): + """Increases the file descriptor limit to the given limit. + + This helps with different platforms having different limits, which can result + from too many open files, sockets, or pipes, resulting in "OSError: [Errno 24] + Too many open files". + + Args: + limit: The new file descriptor limit. The default value was determined + experimentally to not raise too many open files error. + """ + try: + _, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + if limit > hard: + logging.warning( + ( + "Requested limit %d exceeds the system's hard limit %d. Setting" + ' to the maximum allowed value.' + ), + limit, + hard, + ) + limit = hard + resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard)) + logging.info('File descriptor limit set to %d.', limit) + except ValueError as e: + logging.exception('Failed to set file descriptor limit: %s', e) + + +def setup_env( + env: interface.AsyncEnv, + emulator_setup: bool = False, + freeze_datetime: bool = True, +) -> None: + """Performs environment setup and validation.""" + _increase_file_descriptor_limit() + if emulator_setup: + logging.info('Setting up apps on the emulator.') + setup.setup_apps(env) + if freeze_datetime: + logging.info('Freezing datetime.') + datetime_utils.setup_datetime(env.controller) + + +def load_and_setup_env( + console_port: int = 5554, + emulator_setup: bool = False, + freeze_datetime: bool = True, + adb_path: str = android_world_controller.DEFAULT_ADB_PATH, + grpc_port: int = 8554, +) -> interface.AsyncEnv: + """Create environment with `get_env()` and perform env setup and validation. + + Before running this, an emulator must be launched. For example: + + ``` + AVD_NAME=Pixel_6_API_33 # First create an AVD in Android Studio. + ~/Android/Sdk/emulator/emulator -avd $AVD_NAME -no-snapshot -grpc 8554 + ``` + + Args: + console_port: The console port of the existing device. This can usually be + retrieved by looking at the output of `adb devices`. In general, the first + connected device is port 5554, the second is 5556, and so on. + emulator_setup: Perform first-time app setup on the environment if True. + freeze_datetime: Whether to freeze the datetime to a fixed time, October + 2023, to ensure consistent benchmarking. + adb_path: The location of the adb binary. + grpc_port: The port for gRPC communication with the emulator. + + Returns: + An interactable Android environment. + """ + env = _get_env(console_port, adb_path, grpc_port) + setup_env(env, emulator_setup, freeze_datetime) + return env diff --git a/evaluation/AndroidWorld/android_world/env/env_launcher_test.py b/evaluation/AndroidWorld/android_world/env/env_launcher_test.py new file mode 100644 index 0000000..ba74a5b --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/env_launcher_test.py @@ -0,0 +1,64 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from android_env import env_interface +from android_env import loader +from android_env.components import config_classes +from android_world.env import android_world_controller +from android_world.env import env_launcher +from android_world.env import interface + + +class EnvLauncherTest(absltest.TestCase): + + @mock.patch.object(interface, "AsyncAndroidEnv", autospec=True) + @mock.patch.object( + android_world_controller, "AndroidWorldController", autospec=True + ) + @mock.patch.object(loader, "load", autospec=True) + def test_get_env( + self, + mock_loader, + mock_controller, + mock_async_android_env, + ): + mock_android_env = mock.create_autospec(env_interface.AndroidEnvInterface) + mock_loader.return_value = mock_android_env + + env_launcher._get_env(5556, "some_adb_path", 8554) + + mock_loader.assert_called_with( + config=config_classes.AndroidEnvConfig( + task=config_classes.FilesystemTaskConfig( + path=android_world_controller._TASK_PATH + ), + simulator=config_classes.EmulatorConfig( + emulator_launcher=config_classes.EmulatorLauncherConfig( + emulator_console_port=5556, adb_port=5557, grpc_port=8554 + ), + adb_controller=config_classes.AdbControllerConfig( + adb_path="some_adb_path" + ), + ), + ) + ) + mock_controller.assert_called_with(mock_android_env) + mock_async_android_env.assert_called_with(mock_controller.return_value) + + +if __name__ == "__main__": + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/interface.py b/evaluation/AndroidWorld/android_world/env/interface.py new file mode 100644 index 0000000..508b874 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/interface.py @@ -0,0 +1,354 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environment interface for real-time interaction Android.""" + +import abc +import dataclasses +import time +from typing import Any, Optional, Self + +from android_env.components import action_type +from android_world.env import actuation +from android_world.env import adb_utils +from android_world.env import android_world_controller +from android_world.env import json_action +from android_world.env import representation_utils +import dm_env +import numpy as np + + +def _get_no_op_action() -> dict[str, Any]: + """Creates a no-op action; used to retrieve screen & UI tree.""" + return { + 'action_type': np.array(action_type.ActionType.LIFT, dtype=np.int32), + 'touch_position': np.array((0.0, 0.0)), + } + + +@dataclasses.dataclass(frozen=True) +class State: + """State of the Android environment. + + Attributes: + pixels: RGB array of current screen. + forest: Raw UI forest; see android_world_controller.py for more info. + ui_elements: Processed children and stateful UI elements extracted from + forest. + auxiliaries: Additional information about the state. + """ + + pixels: np.ndarray + forest: Any + ui_elements: list[representation_utils.UIElement] + auxiliaries: dict[str, Any] | None = None + + @classmethod + def create_and_infer_elements( + cls, + pixels: np.ndarray, + forest: Any, + screen_size: Optional[tuple[int, int]] = None, + ) -> Self: + """Creates a new instance, inferring UI elements from the forest.""" + + elements = representation_utils.forest_to_ui_elements( + forest, screen_size=screen_size + ) + return cls(pixels, forest, elements) + + +class AsyncEnv(abc.ABC): + """Interface for interacting with a real-time Android device. + + Computing environments, such as Android, run in real-time, independently of + the agent interacting with it. All observations and actions are asynchronous + and OS does not pause when providing observations or when accepting actions. + Changes from action execution may take some time to appear. + """ + + @property + @abc.abstractmethod + def controller(self) -> android_world_controller.AndroidWorldController: + """Returns the controller for the environment.""" + + @abc.abstractmethod + def reset(self, go_home: bool = False) -> State: + """Go home on reset. + + Args: + go_home: Whether to go home during the reset. + """ + + @abc.abstractmethod + def get_state(self, wait_to_stabilize: bool = False) -> State: + """Gets the state of the environment; i.e., screenshot & UI tree. + + In practice this will usually be called after executing an action. Logic + should be implemented, perhaps a simple time.sleep, to ensure the + environment updates after the action. + + Args: + wait_to_stabilize: Whether to wait for the screen to stabilize before + returning state. + + Returns: + Observation containing RGB array of screen, the accessibility forest, + and UI elements derived from the forest. See android_world_controller.py + for + more detail. + """ + + def display_message(self, message: str, header: str = '') -> None: + """Displays a message on the screen.""" + + @abc.abstractmethod + def ask_question( + self, question: str, timeout_seconds: float = -1.0 + ) -> str | None: + """Asks a question to a hypothetical user in the environment. + + Common uses are to ask a question to clarify the user-provided goal, to ask + for help when the agent is stuck, or when there is ambiguity in the current + screen. + + Args: + question: The question to ask the user. + timeout_seconds: The timeout in seconds to wait for a response. If + negative, then wait indefinitely. + + Returns: + The response from the user or None if the user did not answer within the + timeout. + """ + + @abc.abstractmethod + def execute_action(self, action: json_action.JSONAction) -> None: + """Executes action on the environment.""" + + @property + @abc.abstractmethod + def foreground_activity_name(self) -> str: + """Returns the activity name of the app currently opened in foreground.""" + + @property + @abc.abstractmethod + def device_screen_size(self) -> tuple[int, int]: + """Returns the screen size of the environment in pixels: (width, height).""" + + @property + @abc.abstractmethod + def logical_screen_size(self) -> tuple[int, int]: + """Retrieves the logical screen size of the Android device. + + While the physical size is a fixed attribute of the display, the logical + size is flexible and varies based on system settings such as the orientation + or if the resolution is changed. + + Returns: The (width, height) in pixels, denoting the logical dimensions of + the screen. Width and height values are aligned with the device's current + orientation, meaning width is always logical horizontal direction (like in + the landscape orientation width will be the physical vertical direction). + """ + + @abc.abstractmethod + def close(self) -> None: + """Closes the environment.""" + + @property + @abc.abstractmethod + def interaction_cache(self) -> str: + """Returns the interaction cache of the environment.""" + + @abc.abstractmethod + def hide_automation_ui(self) -> None: + """Hides any UI, such as screen coordinates,.""" + + @property + @abc.abstractmethod + def orientation(self) -> int: + """Returns the orientation of the environment. + + Returns: 0 for portrait, 1 for landscape, 2 for reverse portrait, + 3 for reverse landscape. + """ + + @property + @abc.abstractmethod + def physical_frame_boundary(self) -> tuple[int, int, int, int]: + """Returns the physical frame boundary of the environment. + + Returns: First two integers are the coordinates for top left corner, last + two are for lower right corner. All coordinates are given in portrait + orientation. + """ + + +def _process_timestep(timestep: dm_env.TimeStep) -> State: + """Parses timestep observation and returns State.""" + return State( + pixels=timestep.observation['pixels'], + forest=timestep.observation[ + android_world_controller.OBSERVATION_KEY_FOREST + ], + ui_elements=timestep.observation[ + android_world_controller.OBSERVATION_KEY_UI_ELEMENTS + ], + auxiliaries={}, + ) + + +class AsyncAndroidEnv(AsyncEnv): + """Async environment interface using AndroidEnv to communicate with device.""" + + interaction_cache = '' + + def __init__( + self, controller: android_world_controller.AndroidWorldController + ): + self._controller = controller + self._prior_state = None + # Variable used to temporarily save interactions between agent and user. + # Like when agent use answer action to answer user questions, we + # use this to save the agent response. Or later on when agent has the + # ability to ask user question, user's answer will be saved here as well. + self.interaction_cache = '' + + @property + def controller(self) -> android_world_controller.AndroidWorldController: + return self._controller + + def reset(self, go_home: bool = False) -> State: + if go_home: + adb_utils.press_home_button(self.controller) + self.interaction_cache = '' + + return _process_timestep(self.controller.reset()) + + def _get_state(self): + return _process_timestep(self.controller.step(_get_no_op_action())) + + def _get_stable_state( + self, + stability_threshold: int = 3, + sleep_duration: float = 0.5, + timeout: float = 6.0, + ) -> State: + """Checks if the UI elements remain stable over a number of checks and returns the state. + + Args: + stability_threshold: Number of consecutive checks where UI elements must + remain the same to consider UI stable. + sleep_duration: Minimum time in seconds between each check. + timeout: Maximum time in seconds to wait for UI to become stable before + giving up. + + Returns: + The current state of the UI if stability is achieved within the timeout. + """ + if not self._prior_state: + self._prior_state = self._get_state() + if stability_threshold <= 0: + raise ValueError('Stability threshold must be a positive integer.') + + stable_checks = 1 + start_time = time.time() + deadline = start_time + timeout + + while stable_checks < stability_threshold and time.time() < deadline: + iteration_start_time = time.time() + current_state = self._get_state() + + if self._prior_state.ui_elements == current_state.ui_elements: + stable_checks += 1 + if stable_checks == stability_threshold: + break # Exit early if stability is achieved. + else: + stable_checks = 1 # Reset if any change is detected + self._prior_state = current_state + + elapsed_time = time.time() - iteration_start_time + remaining_sleep = sleep_duration - elapsed_time + if remaining_sleep > 0: + sleep_time = min(remaining_sleep, deadline - time.time()) + if sleep_time > 0: + time.sleep(sleep_time) + # If remaining_sleep <= 0, proceed immediately to the next iteration + + return current_state # pylint: disable=undefined-variable + + def get_state(self, wait_to_stabilize: bool = False) -> State: + if wait_to_stabilize: + return self._get_stable_state() + return self._get_state() + + def execute_action(self, action: json_action.JSONAction) -> None: + if action.action_type == json_action.ANSWER: + self.interaction_cache = action.text + if action.text: + self.display_message(action.text, header='Agent answered:') + return + state = self.get_state(wait_to_stabilize=False) + actuation.execute_adb_action( + action, + state.ui_elements, + self.logical_screen_size, + self.controller, + ) + + def hide_automation_ui(self) -> None: + """Hides the coordinates on screen.""" + adb_utils.issue_generic_request( + 'shell settings put system pointer_location 0', self.controller + ) + + def display_message(self, message: str, header: str = '') -> None: + adb_utils.send_android_intent( + command='broadcast', + action='com.example.ACTION_UPDATE_OVERLAY', + env=self.controller, + extras={'task_type_string': header, 'goal_string': message}, + ) + + def ask_question( + self, question: str, timeout_seconds: float = -1.0 + ) -> str | None: + raise NotImplementedError('ask_question is not implemented.') + + @property + def foreground_activity_name(self) -> str: + activity = adb_utils.get_current_activity(self.controller)[0] + if activity: + return activity + else: + return '' + + @property + def device_screen_size(self) -> tuple[int, int]: + return self.controller.device_screen_size + + @property + def logical_screen_size(self) -> tuple[int, int]: + return adb_utils.get_logical_screen_size(self.controller) + + def close(self) -> None: + return self.controller.close() + + @property + def orientation(self) -> int: + return adb_utils.get_orientation(self.controller) + + @property + def physical_frame_boundary(self) -> tuple[int, int, int, int]: + return adb_utils.get_physical_frame_boundary(self.controller) diff --git a/evaluation/AndroidWorld/android_world/env/interface_test.py b/evaluation/AndroidWorld/android_world/env/interface_test.py new file mode 100644 index 0000000..136747d --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/interface_test.py @@ -0,0 +1,91 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from android_world.env import interface +from android_world.env import representation_utils +import numpy as np + + +class InterfaceTest(absltest.TestCase): + + @mock.patch("time.sleep", return_value=None) + def test_ui_stability_true(self, unused_mocked_time_sleep): + stable_ui_elements = [representation_utils.UIElement(text="StableElement")] + states = [ + interface.State( + ui_elements=stable_ui_elements, + pixels=np.empty([1, 2, 3]), + forest=None, + ) + for _ in range(4) + ] + env = interface.AsyncAndroidEnv(mock.MagicMock()) + env._get_state = mock.MagicMock(side_effect=states) + + self.assertEqual( + env._get_stable_state( + stability_threshold=3, sleep_duration=0.1, timeout=1 + ), + states[2], + ) + + def test_ui_stability_false_due_to_timeout(self): + changing_ui_elements = [ + representation_utils.UIElement(text=f"Element{i}") for i in range(10) + ] + env = interface.AsyncAndroidEnv(mock.MagicMock()) + states = [ + interface.State( + ui_elements=[elem], pixels=np.empty([1, 2, 3]), forest=None + ) + for elem in changing_ui_elements + ] + env._get_state = mock.MagicMock(side_effect=states) + self.assertEqual( + env._get_stable_state( + stability_threshold=3, sleep_duration=0.1, timeout=0.41 + ), + states[5], + ) + + @mock.patch("time.sleep", return_value=None) + def test_stability_fluctuates(self, unused_mocked_time_sleep): + env = interface.AsyncAndroidEnv(mock.MagicMock()) + fluctuating_ui_elements = ( + [representation_utils.UIElement(text="Stable")] * 2 + + [representation_utils.UIElement(text="Unstable")] + + [representation_utils.UIElement(text="Stable")] * 3 + + [representation_utils.UIElement(text="Unstable")] + ) + states = [ + interface.State( + ui_elements=[elem], pixels=np.empty([1, 2, 3]), forest=None + ) + for elem in fluctuating_ui_elements + ] + env._get_state = mock.MagicMock(side_effect=states) + cur = env._get_stable_state( + stability_threshold=3, sleep_duration=0.5, timeout=2.5 + ) + self.assertEqual( + cur, + states[5], + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/json_action.py b/evaluation/AndroidWorld/android_world/env/json_action.py new file mode 100644 index 0000000..802ce05 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/json_action.py @@ -0,0 +1,182 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Represents an action for Android interaction, parsed from a JSON format.""" + +import dataclasses +import json +from typing import Optional + + +_JSON_SEPARATORS = (',', ':') + +ANSWER = 'answer' +CLICK = 'click' +DOUBLE_TAP = 'double_tap' +INPUT_TEXT = 'input_text' +KEYBOARD_ENTER = 'keyboard_enter' +LONG_PRESS = 'long_press' +NAVIGATE_BACK = 'navigate_back' +NAVIGATE_HOME = 'navigate_home' +OPEN_APP = 'open_app' +SCROLL = 'scroll' +STATUS = 'status' +SWIPE = 'swipe' +UNKNOWN = 'unknown' +WAIT = 'wait' + +_ACTION_TYPES = ( + CLICK, + DOUBLE_TAP, + SCROLL, + SWIPE, + INPUT_TEXT, + NAVIGATE_HOME, + NAVIGATE_BACK, + KEYBOARD_ENTER, + OPEN_APP, + STATUS, + WAIT, + LONG_PRESS, + ANSWER, + UNKNOWN, +) + +_SCROLL_DIRECTIONS = ('left', 'right', 'down', 'up') + +# Keys of JSON action. +ACTION_TYPE = 'action_type' +INDEX = 'index' +X = 'x' +Y = 'y' +TEXT = 'text' +DIRECTION = 'direction' +APP_NAME = 'app_name' +GOAL_STATUS = 'goal_status' + + +@dataclasses.dataclass() +class JSONAction: + """Represents a parsed JSON action. + + # Example + result_json = {'action_type': 'click', 'x': %d, 'y': %d} + action = JSONAction(**result_json) + + Attributes: + action_type: The action type. + index: The index to click, if action is a click. Either an index or a + should be provided. See x, y attributes below. + x: The x position to click, if the action is a click. + y: The y position to click, if the action is a click. + text: The text to type, if action is type. + direction: The direction to scroll, if action is scroll. + goal_status: If the status is a 'status' type, indicates the status of the + goal. + app_name: The app name to launch, if the action type is 'open_app'. + keycode: Keycode actions are necessary for an agent to interact with complex + UI elements (like large textareas) that can't be accessed or controlled by + simply taping, ensuring precise control over navigation and selection in + the interface. + """ + + action_type: Optional[str] = None + index: Optional[str | int] = None + element: Optional[str] = None + x: Optional[int] = None + y: Optional[int] = None + text: Optional[str] = None + direction: Optional[str] = None + goal_status: Optional[str] = None + app_name: Optional[str] = None + keycode: Optional[str] = None + + def __post_init__(self): + if self.action_type not in _ACTION_TYPES: + raise ValueError(f'Invalid action type: {self.action_type}') + # if self.index is not None: + # self.index = int(self.index) + # if self.x is not None or self.y is not None: + # raise ValueError('Either an index or a should be provided.') + + + if self.element is not None: + self.element = str(self.element) + if self.x is not None or self.y is not None: + raise ValueError('Either an index or a should be provided.') + + if self.direction and self.direction not in _SCROLL_DIRECTIONS: + raise ValueError(f'Invalid scroll direction: {self.direction}') + if self.text is not None and not isinstance(self.text, str): + self.text = str(self.text) + if self.keycode is not None and not self.keycode.startswith('KEYCODE_'): + raise ValueError(f'Invalid keycode: {self.keycode}') + + def __repr__(self) -> str: + properties = [] + for key, value in self.__dict__.items(): + if value is not None: + if isinstance(value, float): + value = f'{value:.3f}' + properties.append(f'{key}={value!r}') + return f"JSONAction({', '.join(properties)})" + + def __eq__(self, other): + if isinstance(other, JSONAction): + return _compare_actions(self, other) + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def json_str(self) -> str: + non_null = {} + for key, value in self.__dict__.items(): + if value is not None: + non_null[key] = value + return json.dumps(non_null, separators=_JSON_SEPARATORS) + + +def _compare_actions(a: JSONAction, b: JSONAction) -> bool: + """Compares two JSONActions. + + Args: + a: The first action. + b: The second action. + + Returns: + If the actions are equal. + """ + # Ignore cases. + if a.app_name is not None and b.app_name is not None: + app_name_match = a.app_name.lower() == b.app_name.lower() + else: + app_name_match = a.app_name == b.app_name + + if a.text is not None and b.text is not None: + text_match = a.text.lower() == b.text.lower() + else: + text_match = a.text == b.text + + # Compare the non-metadata fields. + return ( + app_name_match + and text_match + and a.action_type == b.action_type + and a.x == b.x + and a.y == b.y + and a.keycode == b.keycode + and a.direction == b.direction + and a.goal_status == b.goal_status + ) diff --git a/evaluation/AndroidWorld/android_world/env/representation_utils.py b/evaluation/AndroidWorld/android_world/env/representation_utils.py new file mode 100644 index 0000000..a9c5912 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/representation_utils.py @@ -0,0 +1,219 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for processing and representing accessibility trees.""" + +import dataclasses +from typing import Any, Optional +import xml.etree.ElementTree as ET +from android_env.proto.a11y import android_accessibility_forest_pb2 + + +@dataclasses.dataclass +class BoundingBox: + """Class for representing a bounding box.""" + + x_min: float | int + x_max: float | int + y_min: float | int + y_max: float | int + + @property + def center(self) -> tuple[float, float]: + """Gets center of bounding box.""" + return (self.x_min + self.x_max) / 2.0, (self.y_min + self.y_max) / 2.0 + + @property + def width(self) -> float | int: + """Gets width of bounding box.""" + return self.x_max - self.x_min + + @property + def height(self) -> float | int: + """Gets height of bounding box.""" + return self.y_max - self.y_min + + @property + def area(self) -> float | int: + return self.width * self.height + + +@dataclasses.dataclass +class UIElement: + """Represents a UI element.""" + + text: Optional[str] = None + content_description: Optional[str] = None + class_name: Optional[str] = None + bbox: Optional[BoundingBox] = None + bbox_pixels: Optional[BoundingBox] = None + hint_text: Optional[str] = None + is_checked: Optional[bool] = None + is_checkable: Optional[bool] = None + is_clickable: Optional[bool] = None + is_editable: Optional[bool] = None + is_enabled: Optional[bool] = None + is_focused: Optional[bool] = None + is_focusable: Optional[bool] = None + is_long_clickable: Optional[bool] = None + is_scrollable: Optional[bool] = None + is_selected: Optional[bool] = None + is_visible: Optional[bool] = None + package_name: Optional[str] = None + resource_name: Optional[str] = None + tooltip: Optional[str] = None + resource_id: Optional[str] = None + metadata: Optional[dict[str, Any]] = None + + +def accessibility_node_to_ui_element( + node: Any, + screen_size: Optional[tuple[int, int]] = None, +) -> UIElement: + """Converts a node from an accessibility tree to a UIElement.""" + + def text_or_none(text: Optional[str]) -> Optional[str]: + """Returns None if text is None or 0 length.""" + return text if text else None + + node_bbox = node.bounds_in_screen + bbox_pixels = BoundingBox( + node_bbox.left, node_bbox.right, node_bbox.top, node_bbox.bottom + ) + + if screen_size is not None: + bbox_normalized = _normalize_bounding_box(bbox_pixels, screen_size) + else: + bbox_normalized = None + + return UIElement( + text=text_or_none(node.text), + content_description=text_or_none(node.content_description), + class_name=text_or_none(node.class_name), + bbox=bbox_normalized, + bbox_pixels=bbox_pixels, + hint_text=text_or_none(node.hint_text), + is_checked=node.is_checked, + is_checkable=node.is_checkable, + is_clickable=node.is_clickable, + is_editable=node.is_editable, + is_enabled=node.is_enabled, + is_focused=node.is_focused, + is_focusable=node.is_focusable, + is_long_clickable=node.is_long_clickable, + is_scrollable=node.is_scrollable, + is_selected=node.is_selected, + is_visible=node.is_visible_to_user, + package_name=text_or_none(node.package_name), + resource_name=text_or_none(node.view_id_resource_name), + ) + + +def _normalize_bounding_box( + node_bbox: BoundingBox, + screen_width_height_px: tuple[int, int], +) -> BoundingBox: + width, height = screen_width_height_px + return BoundingBox( + node_bbox.x_min / width, + node_bbox.x_max / width, + node_bbox.y_min / height, + node_bbox.y_max / height, + ) + + +def forest_to_ui_elements( + forest: android_accessibility_forest_pb2.AndroidAccessibilityForest | Any, + exclude_invisible_elements: bool = False, + screen_size: Optional[tuple[int, int]] = None, +) -> list[UIElement]: + """Extracts nodes from accessibility forest and converts to UI elements. + + We extract all nodes that are either leaf nodes or have content descriptions + or is scrollable. + + Args: + forest: The forest to extract leaf nodes from. + exclude_invisible_elements: True if invisible elements should not be + returned. + screen_size: The size of the device screen in pixels (width, height). + + Returns: + The extracted UI elements. + """ + elements = [] + for window in forest.windows: + for node in window.tree.nodes: + if not node.child_ids or node.content_description or node.is_scrollable: + if exclude_invisible_elements and not node.is_visible_to_user: + continue + else: + elements.append(accessibility_node_to_ui_element(node, screen_size)) + return elements + + +def _parse_ui_hierarchy(xml_string: str) -> dict[str, Any]: + """Parses the UI hierarchy XML into a dictionary structure.""" + root = ET.fromstring(xml_string) + + def parse_node(node): + result = node.attrib + result['children'] = [parse_node(child) for child in node] + return result + + return parse_node(root) + + +def xml_dump_to_ui_elements(xml_string: str) -> list[UIElement]: + """Converts a UI hierarchy XML dump from uiautomator dump to UIElements.""" + parsed_hierarchy = _parse_ui_hierarchy(xml_string) + ui_elements = [] + + def process_node(node, is_root): + bounds = node.get('bounds') + if bounds: + x_min, y_min, x_max, y_max = map( + int, bounds.strip('[]').replace('][', ',').split(',') + ) + bbox = BoundingBox(x_min, x_max, y_min, y_max) + else: + bbox = None + + ui_element = UIElement( + text=node.get('text'), + content_description=node.get('content-desc'), + class_name=node.get('class'), + bbox=bbox, + bbox_pixels=bbox, + is_checked=node.get('checked') == 'true', + is_checkable=node.get('checkable') == 'true', + is_clickable=node.get('clickable') == 'true', + is_enabled=node.get('enabled') == 'true', + is_focused=node.get('focused') == 'true', + is_focusable=node.get('focusable') == 'true', + is_long_clickable=node.get('long-clickable') == 'true', + is_scrollable=node.get('scrollable') == 'true', + is_selected=node.get('selected') == 'true', + package_name=node.get('package'), + resource_id=node.get('resource-id'), + is_visible=True, + ) + if not is_root: + ui_elements.append(ui_element) + + for child in node.get('children', []): + process_node(child, is_root=False) + + process_node(parsed_hierarchy, is_root=True) + return ui_elements diff --git a/evaluation/AndroidWorld/android_world/env/representation_utils_test.py b/evaluation/AndroidWorld/android_world/env/representation_utils_test.py new file mode 100644 index 0000000..0ee3373 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/representation_utils_test.py @@ -0,0 +1,119 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from android_world.env import representation_utils + + +@dataclasses.dataclass(frozen=True) +class BoundsInScreen: + left: int + right: int + top: int + bottom: int + + +class TestAccessibilityNodeToUIElement(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='rectangle_to_rectangle_normalization', + node_bounds=BoundsInScreen(0, 150, 0, 100), + screen_size=(500, 500), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.0, 0.3, 0.0, 0.2 + ), + ), + dict( + testcase_name='square_to_square_normalization', + node_bounds=BoundsInScreen(100, 200, 100, 200), + screen_size=(1000, 1000), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.1, 0.2, 0.1, 0.2 + ), + ), + dict( + testcase_name='square_to_rectangle_normalization', + node_bounds=BoundsInScreen(0, 100, 0, 100), + screen_size=(1000, 500), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.0, 0.1, 0.0, 0.2 + ), + ), + dict( + testcase_name='no_change_square_normalization', + node_bounds=BoundsInScreen(0, 100, 0, 100), + screen_size=(100, 100), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.0, 1.0, 0.0, 1.0 + ), + ), + dict( + testcase_name='no_change_rectangle_normalization', + node_bounds=BoundsInScreen(0, 200, 0, 100), + screen_size=(200, 100), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.0, 1.0, 0.0, 1.0 + ), + ), + dict( + testcase_name='normalization_causing_dimensions_to_grow', + node_bounds=BoundsInScreen(0, 50, 0, 50), + screen_size=(200, 200), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.0, 0.25, 0.0, 0.25 + ), + ), + dict( + testcase_name='zero_size_bbox_normalization', + node_bounds=BoundsInScreen(0, 0, 0, 0), + screen_size=(100, 100), + expected_normalized_bbox=representation_utils.BoundingBox( + 0.0, 0.0, 0.0, 0.0 + ), + ), + dict( + testcase_name='no_normalization', + node_bounds=BoundsInScreen(10, 20, 11, 13), + screen_size=None, + expected_normalized_bbox=None, + ), + ) + def test_normalize_bboxes( + self, node_bounds, screen_size, expected_normalized_bbox + ): + node = mock.MagicMock() + node.bounds_in_screen = node_bounds + + ui_element = representation_utils.accessibility_node_to_ui_element( + node, screen_size + ) + self.assertEqual(ui_element.bbox_pixels.x_min, node_bounds.left) + self.assertEqual(ui_element.bbox_pixels.x_max, node_bounds.right) + self.assertEqual(ui_element.bbox_pixels.y_min, node_bounds.top) + self.assertEqual(ui_element.bbox_pixels.y_max, node_bounds.bottom) + + if screen_size is not None: + ui_element.bbox = representation_utils._normalize_bounding_box( + ui_element.bbox_pixels, screen_size + ) + self.assertEqual(ui_element.bbox, expected_normalized_bbox) + + +if __name__ == '__main__': + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/setup_device/apps.py b/evaluation/AndroidWorld/android_world/env/setup_device/apps.py new file mode 100644 index 0000000..a219c01 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/setup_device/apps.py @@ -0,0 +1,757 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module defines classes for setting up various applications in the Android World environment. + +Each class represents an app and includes methods for retrieving its APK name +and performing setup tasks specific to that app using the Android Environment +Interface. +""" + +import abc +import os +import time +from typing import Iterable +from absl import logging +from android_world.env import adb_utils +from android_world.env import interface +from android_world.env import tools +from android_world.task_evals.information_retrieval import joplin_app_utils +from android_world.utils import file_utils +import requests + + +APP_DATA = os.path.join(os.path.dirname(__file__), 'app_data') + + +def download_app_data(file_name: str) -> str: + """Downloads file from a GCS bucket, if not cached, and installs it.""" + cache_dir = "/tmp/android_world/app_data" + remote_url = ( + f"https://storage.googleapis.com/gresearch/android_world/{file_name}" + ) + full_path = os.path.join(cache_dir, file_name) + os.makedirs(cache_dir, exist_ok=True) + if not os.path.isfile(full_path): + logging.info("Downloading file_name %s to cache %s", file_name, cache_dir) + response = requests.get(remote_url) + if response.status_code == 200: + with open(full_path, "wb") as file: + file.write(response.content) + else: + raise RuntimeError( + f"Failed to download file_name from {remote_url}, status code:" + f" {response.status_code}" + ) + else: + logging.info("File already %s exists in cache %s", file_name, cache_dir) + return full_path + + +class AppSetup(abc.ABC): + """Abstract class for setting up an app.""" + + # The APK name of the app. This will assumed to be downloaded in setup.py and + # each instance of an AppSetup will be referenced using the `apk` name as the + # key for downloading. Some apps contain multiple APK names since different + # versions are distributed depending on the architecture. E.g., M1 Macs + # require different APKs for some apps. + apk_names = "" + + # The short name of the app, as used by adb_utils. + app_name = "" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + """Performs setup tasks specific to the app.""" + adb_utils.clear_app_data( + adb_utils.extract_package_name( + adb_utils.get_adb_activity(cls.app_name) + ), + env.controller, + ) + + @classmethod + def _copy_data_to_device( + cls, + files: Iterable[str], + device_path: str, + env: interface.AsyncEnv, + ) -> None: + """Helper method for copying app data to the device. + + Args: + files: Names of files to copy from {APP_DATA}/app_name/ to {device_path}. + device_path: Location on device to load the files. + env: Android environment. + """ + for file in files: + copy_to_device = lambda path: adb_utils.check_ok( + file_utils.copy_data_to_device( + path, + device_path, + env.controller, + ), + f"Failed to copy {device_path} to device.", + ) + + full_path = download_app_data(file) + copy_to_device(full_path) + + +class CameraApp(AppSetup): + """Class for setting up pre-installed Camera app.""" + + app_name = "camera" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Prevent pop-up asking for permission. + adb_utils.grant_permissions( + adb_utils.extract_package_name( + adb_utils.get_adb_activity(cls.app_name) + ), + "android.permission.ACCESS_COARSE_LOCATION", + env.controller, + ) + + # Click through onboarding screens during first time launch. + adb_utils.launch_app(cls.app_name, env.controller) + try: + controller = tools.AndroidToolController(env=env.controller) + time.sleep(2.0) + controller.click_element("NEXT") + time.sleep(2.0) + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class ChromeApp(AppSetup): + """Class for setting up pre-installed Chrome app.""" + + app_name = "chrome" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Click through onboarding screens during first time launch. + adb_utils.launch_app(cls.app_name, env.controller) + try: + controller = tools.AndroidToolController(env=env.controller) + time.sleep(2.0) + # Welcome screen. + controller.click_element("Accept & continue") + time.sleep(2.0) + # Turn on sync? + controller.click_element("No thanks") + time.sleep(2.0) + # Enable notifications? + controller.click_element("No thanks") + time.sleep(2.0) + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class ClockApp(AppSetup): + """Class for setting up pre-installed Clock app.""" + + app_name = "clock" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Open once for initial tool tip display. + adb_utils.launch_app(cls.app_name, env.controller) + time.sleep(2.0) + adb_utils.close_app(cls.app_name, env.controller) + + +class ContactsApp(AppSetup): + """Class for setting up pre-installed Contacts app.""" + + app_name = "contacts" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Click through onboarding screens during first time launch. + adb_utils.launch_app(cls.app_name, env.controller) + try: + controller = tools.AndroidToolController(env=env.controller) + time.sleep(2.0) + # Back up & organize your contacts with Google. + controller.click_element("Skip") + time.sleep(2.0) + # Allow Contacts to send you notifications? + controller.click_element("Don't allow") + time.sleep(2.0) + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class DialerApp(AppSetup): + """Class for setting up pre-installed Dialer app.""" + + app_name = "dialer" + + +class FilesApp(AppSetup): + """Class for setting up pre-installed Files app.""" + + app_name = "files" + + +class SettingsApp(AppSetup): + """Class for setting up pre-installed Settings app.""" + + app_name = "settings" + + +class MarkorApp(AppSetup): + """Class for setting up Markor app.""" + + apk_names = ("net.gsantner.markor_146.apk",) + app_name = "markor" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + adb_utils.launch_app(cls.app_name, env.controller) + try: + controller = tools.AndroidToolController(env=env.controller) + time.sleep(2.0) + controller.click_element("NEXT") + time.sleep(2.0) + controller.click_element("NEXT") + time.sleep(2.0) + controller.click_element("NEXT") + time.sleep(2.0) + controller.click_element("NEXT") + time.sleep(2.0) + controller.click_element("DONE") + time.sleep(2.0) + + controller.click_element("OK") + time.sleep(2.0) + controller.click_element("Allow access to manage all files") + time.sleep(2.0) + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class AndroidWorldApp(AppSetup): + """Class for setting up Android World app. + + AndroidWorld app provides on-screen visualization of tasks and rewards. + """ + + apk_names = ("androidworld.apk",) + app_name = "android world" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.issue_generic_request( + [ + "shell", + "appops", + "set", + adb_utils.extract_package_name( + adb_utils.get_adb_activity("android world") + ), + "android:system_alert_window", + "allow", + ], + env.controller, + ) + adb_utils.launch_app(cls.app_name, env.controller) + adb_utils.close_app(cls.app_name, env.controller) + + +class ClipperApp(AppSetup): + """Class for setting up clipper app.""" + + apk_names = ("clipper.apk",) + app_name = "clipper" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + controller = tools.AndroidToolController(env=env.controller) + adb_utils.launch_app(cls.app_name, env.controller) + try: + time.sleep(2.0) + controller.click_element("Continue") + time.sleep(2.0) + controller.click_element("OK") + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class SimpleCalendarProApp(AppSetup): + """Class for setting up simple calendar pro app.""" + + apk_names = ("com.simplemobiletools.calendar.pro_238.apk",) + app_name = "simple calendar pro" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.launch_app(cls.app_name, env.controller) + adb_utils.close_app(cls.app_name, env.controller) + + # Grant permissions for calendar app. + calendar_package = adb_utils.extract_package_name( + adb_utils.get_adb_activity("simple calendar pro") + ) + adb_utils.grant_permissions( + calendar_package, + "android.permission.READ_CALENDAR", + env.controller, + ) + adb_utils.grant_permissions( + calendar_package, + "android.permission.WRITE_CALENDAR", + env.controller, + ) + adb_utils.grant_permissions( + calendar_package, + "android.permission.POST_NOTIFICATIONS", + env.controller, + ) + + +class TasksApp(AppSetup): + """Class for setting up Tasks app.""" + + apk_names = ("org.tasks_130605.apk",) + app_name = "tasks" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.launch_app(cls.app_name, env.controller) + adb_utils.close_app(cls.app_name, env.controller) + + +class SimpleDrawProApp(AppSetup): + """Class for setting up simple draw pro app.""" + + apk_names = ("com.simplemobiletools.draw.pro_79.apk",) + app_name = "simple draw pro" + + +class SimpleGalleryProApp(AppSetup): + """Class for setting up Simple Gallery Pro app.""" + + PERMISSIONS = ( + "android.permission.WRITE_EXTERNAL_STORAGE", + "android.permission.ACCESS_MEDIA_LOCATION", + "android.permission.READ_MEDIA_IMAGES", + "android.permission.READ_MEDIA_VIDEO", + "android.permission.POST_NOTIFICATIONS", + ) + + apk_names = ("com.simplemobiletools.gallery.pro_396.apk",) + app_name = "simple gallery pro" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Grant permissions for gallery app. + package = adb_utils.extract_package_name( + adb_utils.get_adb_activity(cls.app_name) + ) + for permission in cls.PERMISSIONS: + adb_utils.grant_permissions(package, permission, env.controller) + + adb_utils.launch_app("simple gallery pro", env.controller) + try: + controller = tools.AndroidToolController(env=env.controller) + time.sleep(2.0) + controller.click_element("All files") + time.sleep(2.0) + controller.click_element("Allow access to manage all files") + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class SimpleSMSMessengerApp(AppSetup): + """Class for setting up Simple SMS Messenger app.""" + + apk_names = ("com.simplemobiletools.smsmessenger_85.apk",) + app_name = "simple sms messenger" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Make Simple Messenger the default SMS app. + adb_utils.set_default_app( + "sms_default_application", + adb_utils.extract_package_name( + adb_utils.get_adb_activity("simple sms messenger") + ), + env.controller, + ) + + adb_utils.launch_app(cls.app_name, env.controller) + try: + controller = tools.AndroidToolController(env=env.controller) + time.sleep(2.0) + controller.click_element("SMS Messenger") + time.sleep(2.0) + controller.click_element("Set as default") + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class AudioRecorder(AppSetup): + """Class for setting up Audio Recorder app.""" + + apk_names = ("com.dimowner.audiorecorder_926.apk",) + app_name = "audio recorder" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.grant_permissions( + "com.dimowner.audiorecorder", + "android.permission.RECORD_AUDIO", + env.controller, + ) + adb_utils.grant_permissions( + "com.dimowner.audiorecorder", + "android.permission.POST_NOTIFICATIONS", + env.controller, + ) + + # Launch the app + adb_utils.issue_generic_request( + [ + "shell", + "monkey", + "-p", + "com.dimowner.audiorecorder", + "-candroid.intent.category.LAUNCHER", + "1", + ], + env.controller, + ) + time.sleep(2.0) # Let app setup. + adb_utils.close_app(cls.app_name, env.controller) + + +class MiniWobApp(AppSetup): + """Class for setting up MiniWoB app.""" + + apk_names = ("miniwobapp.apk",) + app_name = "miniwob" + + +class ExpenseApp(AppSetup): + """Class for setting up Arduia Pro Expense app.""" + + apk_names = ("com.arduia.expense_11.apk",) + app_name = "pro expense" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.launch_app(cls.app_name, env.controller) + try: + time.sleep(2.0) + controller = tools.AndroidToolController(env=env.controller) + controller.click_element("NEXT") + time.sleep(2.0) + controller.click_element("CONTINUE") + time.sleep(3.0) + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class RecipeApp(AppSetup): + """Class for setting up Broccoli Recipe app.""" + + apk_names = ("com.flauschcode.broccoli_1020600.apk",) + app_name = "broccoli app" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.launch_app(cls.app_name, env.controller) + time.sleep(2.0) + adb_utils.close_app(cls.app_name, env.controller) + + +class OsmAndApp(AppSetup): + """Class for setting up OsmAndApp map app. + + OsmAnd handles the following intents (among others*). In addition to geo + URIs, it can handle intents using the Google Maps API as well as a few + other apps not listed here. + + Android geo intents: + geo:latitude,longitude + geo:latitude,longitude?z=zoom + geo:0,0?q=my+street+address + geo:0,0?q=business+near+city + + OsmAnd specific intents: + http://download.osmand.net/go?lat=&lon=&z= + http://osmand.net/go?lat=34&lon=-106&z=11 + + Google: + google.navigation:q=34.99393,-106.61568 + http://maps.google.com/maps?q=N34.939,W106 + http://maps.google.com/maps?f=d&saddr=My+Location&daddr=lat,lon + http://maps.google.com/maps/@34,-106,11z + http://maps.google.com/maps/ll=34.99393,-106.61568,z=11 + https://maps.google.com/maps?q=loc:-21.8835112,-47.7838932 (Name) + http://maps.google.com/maps?q=34,-106 + http://www.google.com/maps/dir/Current+Location/34,-106 + + * https://osmand.net/docs/technical/algorithms/osmand-intents/ + """ + + PERMISSIONS = ( + "android.permission.POST_NOTIFICATIONS", + # For other possible permissions see the manifest + # https://github.com/osmandapp/OsmAnd/blob/master/OsmAnd/AndroidManifest.xml + ) + + DEVICE_MAPS_PATH = "/storage/emulated/0/Android/data/net.osmand/files/" + + MAP_NAMES = ("Liechtenstein_europe.obf",) + + apk_names = ("net.osmand-4.6.13.apk",) + app_name = "osmand" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + adb_utils.launch_app(cls.app_name, env.controller) + time.sleep(2.0) + + try: + controller = tools.AndroidToolController(env=env.controller) + controller.click_element("SKIP DOWNLOAD") + time.sleep(2.0) + except ValueError: + logging.warn( + "First time setup did not click through all anticipated screens." + ) + finally: + adb_utils.close_app(cls.app_name, env.controller) + + # Grant permissions for OsmAnd mapping app. + package = adb_utils.extract_package_name( + adb_utils.get_adb_activity(cls.app_name) + ) + for permission in cls.PERMISSIONS: + adb_utils.grant_permissions(package, permission, env.controller) + + # Copy maps to data directory. + cls._copy_data_to_device(cls.MAP_NAMES, cls.DEVICE_MAPS_PATH, env) + + # Make sure security context is correct so that the files can be accessed. + for map_file in cls.MAP_NAMES: + adb_utils.check_ok( + adb_utils.issue_generic_request( + [ + "shell", + "chcon", + "u:object_r:media_rw_data_file:s0", + os.path.join(cls.DEVICE_MAPS_PATH, map_file), + ], + env.controller, + ) + ) + + adb_utils.close_app(cls.app_name, env.controller) + + +class OpenTracksApp(AppSetup): + """Class for setting up OpenTracks app.""" + + apk_names = ("de.dennisguse.opentracks_5705.apk",) + app_name = "open tracks sports tracker" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + adb_utils.launch_app(cls.app_name, env.controller) + adb_utils.close_app(cls.app_name, env.controller) + + # Grant permissions for open tracks app. + open_tracks_package = adb_utils.extract_package_name( + adb_utils.get_adb_activity("open tracks") + ) + adb_utils.grant_permissions( + open_tracks_package, + "android.permission.ACCESS_COARSE_LOCATION", + env.controller, + ) + adb_utils.grant_permissions( + open_tracks_package, + "android.permission.ACCESS_FINE_LOCATION", + env.controller, + ) + adb_utils.grant_permissions( + open_tracks_package, + "android.permission.POST_NOTIFICATIONS", + env.controller, + ) + time.sleep(2.0) + controller = tools.AndroidToolController(env=env.controller) + # Give permission for bluetooth, can't be done through adb. + controller.click_element("Allow") + adb_utils.launch_app("activity tracker", env.controller) + adb_utils.close_app("activity tracker", env.controller) + + +class VlcApp(AppSetup): + """Class for setting up VLC app.""" + + videos_path = "/storage/emulated/0/VLCVideos" # Store videos here. + apk_names = ( + "org.videolan.vlc_13050408.apk", + "org.videolan.vlc_13050407.apk", # Arch86 for Mac M1/M2/etc. + ) + app_name = "vlc" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + package = adb_utils.extract_package_name( + adb_utils.get_adb_activity(cls.app_name) + ) + adb_utils.grant_permissions( + package, "android.permission.POST_NOTIFICATIONS", env.controller + ) + if not file_utils.check_directory_exists(cls.videos_path, env.controller): + file_utils.mkdir(cls.videos_path, env.controller) + + time.sleep(2.0) + # Launch similar to opening app from app launcher. This runs setup logic not + # available using `adb shell am start`. Specifically, it will create the + # /data/data/org.videolan.vlc/app_db/vlc_media.db file. + adb_utils.issue_generic_request( + [ + "shell", + "monkey", + "-p", + package, + "-candroid.intent.category.LAUNCHER", + "1", + ], + env.controller, + ) + time.sleep(2.0) + try: + controller = tools.AndroidToolController(env=env.controller) + controller.click_element("Skip") + time.sleep(2.0) + controller.click_element("GRANT PERMISSION") + time.sleep(2.0) + controller.click_element("OK") + time.sleep(2.0) + controller.click_element("Allow access to manage all files") + finally: + adb_utils.close_app(cls.app_name, env.controller) + + +class JoplinApp(AppSetup): + """Class for setting up Joplin app.""" + + apk_names = ("net.cozic.joplin_2097740.apk",) + app_name = "joplin" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + + # Grant permissions for joplin app. + joplin_package = adb_utils.extract_package_name( + adb_utils.get_adb_activity(cls.app_name) + ) + adb_utils.grant_permissions( + joplin_package, + "android.permission.ACCESS_COARSE_LOCATION", + env.controller, + ) + adb_utils.grant_permissions( + joplin_package, + "android.permission.ACCESS_FINE_LOCATION", + env.controller, + ) + + # Launch the app, similar to how user launches it from App Drawer. + adb_utils.issue_generic_request( + [ + "shell", + "monkey", + "-p", + joplin_package, + "-candroid.intent.category.LAUNCHER", + "1", + ], + env.controller, + ) + time.sleep(10.0) + adb_utils.close_app(cls.app_name, env.controller) + time.sleep(10.0) + + # Calling clear_dbs() without having added a note seems to make + # the sqlite table inaccessible. Every subsequent call to clear_dbs() + # works fine. + joplin_app_utils.create_note( + folder="new folder", + title="new_note", + body="", + folder_mapping={}, + env=env, + ) + joplin_app_utils.clear_dbs(env) + + +class RetroMusicApp(AppSetup): + """Class for setting up Retro Music.""" + + PERMISSIONS = ( + "android.permission.READ_MEDIA_AUDIO", + "android.permission.POST_NOTIFICATIONS", + ) + + apk_names = ("code.name.monkey.retromusic_10603.apk",) + app_name = "retro music" + + @classmethod + def setup(cls, env: interface.AsyncEnv) -> None: + super().setup(env) + package = adb_utils.extract_package_name( + adb_utils.get_adb_activity("retro music") + ) + for permission in cls.PERMISSIONS: + adb_utils.grant_permissions(package, permission, env.controller) + + adb_utils.launch_app(cls.app_name, env.controller) + time.sleep(2.0) + adb_utils.close_app(cls.app_name, env.controller) diff --git a/evaluation/AndroidWorld/android_world/env/setup_device/setup.py b/evaluation/AndroidWorld/android_world/env/setup_device/setup.py new file mode 100644 index 0000000..68d9dcd --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/setup_device/setup.py @@ -0,0 +1,134 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Setup tool for Android World. + +It does the following: + +* APK Management: Automates installations of apks needed for Android World. +* Sets up environment: Configures emulator with necessary permissions, using adb + and basic automation. +""" + +import os +from typing import Type + +from absl import logging +from android_env import env_interface +from android_env.components import errors +from android_world.env import adb_utils +from android_world.env import interface +from android_world.env.setup_device import apps +from android_world.utils import app_snapshot + +# APKs required for Android World. +_APPS = ( + # keep-sorted start + apps.AndroidWorldApp, + apps.AudioRecorder, + apps.CameraApp, + apps.ChromeApp, + apps.ClipperApp, + apps.ClockApp, + apps.ContactsApp, + apps.DialerApp, + apps.ExpenseApp, + apps.FilesApp, + apps.JoplinApp, + apps.MarkorApp, + apps.MiniWobApp, + apps.OpenTracksApp, + apps.OsmAndApp, + apps.RecipeApp, + apps.RetroMusicApp, + apps.SettingsApp, + apps.SimpleCalendarProApp, + apps.SimpleDrawProApp, + apps.SimpleGalleryProApp, + apps.SimpleSMSMessengerApp, + apps.TasksApp, + apps.VlcApp, + # keep-sorted end +) + + +def get_app_mapping(app_name: str) -> Type[apps.AppSetup]: + mapping = {app.app_name: app for app in _APPS} + return mapping[app_name] + + +def download_and_install_apk( + apk: str, raw_env: env_interface.AndroidEnvInterface +) -> None: + """Downloads APK from remote location and installs it.""" + path = apps.download_app_data(apk) + adb_utils.install_apk(path, raw_env) + + +def setup_app(app: Type[apps.AppSetup], env: interface.AsyncEnv) -> None: + """Sets up a single app.""" + try: + logging.info("Setting up app %s", app.app_name) + app.setup(env) + except ValueError as e: + logging.warning( + "Failed to automatically setup app %s: %s.\n\nYou will need to" + " manually setup the app.", + app.app_name, + e, + ) + app_snapshot.save_snapshot(app.app_name, env.controller) + + +def maybe_install_app( + app: Type[apps.AppSetup], env: interface.AsyncEnv +) -> None: + """Installs all APKs for Android World.""" + if not app.apk_names: # Ignore 1p apps that don't have an APK. + return + print(f"Installing app: {app.app_name}.") + apk_installed = False + for apk_name in app.apk_names: + try: + download_and_install_apk(apk_name, env.controller.env) + apk_installed = True + break + except errors.AdbControllerError: + # Try apk compiled for a different architecture, e.g., Mac M1. + continue + if not apk_installed: + raise RuntimeError(f"Failed to download and install APK for {app.app_name}") + + +def setup_apps(env: interface.AsyncEnv) -> None: + """Sets up apps for Android World. + + Args: + env: The Android environment. + + Raises: + RuntimeError: If cannot install APK. + """ + # Make sure quick-settings are not displayed, which can override foreground + # apps, and impede UI navigation required for setting up. + adb_utils.press_home_button(env.controller) + adb_utils.set_root_if_needed(env.controller) + + print( + "Installing and setting up applications on Android device. Please do not" + " interact with device while installation is running." + ) + for app in _APPS: + maybe_install_app(app, env) + setup_app(app, env) diff --git a/evaluation/AndroidWorld/android_world/env/setup_device/setup_test.py b/evaluation/AndroidWorld/android_world/env/setup_device/setup_test.py new file mode 100644 index 0000000..7aff779 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/setup_device/setup_test.py @@ -0,0 +1,113 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from android_env.components import errors +from android_world.env import adb_utils +from android_world.env import interface +from android_world.env import tools +from android_world.env.setup_device import apps +from android_world.env.setup_device import setup +from android_world.utils import app_snapshot + + +class SetupTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_issue_generic_request = self.enter_context( + mock.patch.object(adb_utils, "issue_generic_request") + ) + + @mock.patch.object(tools, "AndroidToolController") + @mock.patch.object(setup, "download_and_install_apk") + @mock.patch.object(app_snapshot, "save_snapshot") + def test_setup_apps(self, mock_save_snapshot, mock_install_apk, unused_tools): + env = mock.create_autospec(interface.AsyncEnv) + mock_app_setups = { + app_class: mock.patch.object(app_class, "setup").start() + for app_class in setup._APPS + } + + setup.setup_apps(env) + + for app_class in setup._APPS: + if app_class.apk_names: # 1P apps do not have APKs. + mock_install_apk.assert_any_call( + app_class.apk_names[0], env.controller.env + ) + mock_app_setups[app_class].assert_any_call(env) + mock_save_snapshot.assert_any_call(app_class.app_name, env.controller) + + +class _App(apps.AppSetup): + + def __init__(self, apk_names, app_name): + self.apk_names = apk_names + self.app_name = app_name + + +class InstallApksTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.env = mock.create_autospec(interface.AsyncEnv) + self.mockdownload_and_install_apk = self.enter_context( + mock.patch.object(setup, "download_and_install_apk") + ) + self.apps = [ + _App(apk_names=["apk1", "apk2"], app_name="App1"), + _App(apk_names=[], app_name="App2"), # No APKs + _App(apk_names=["apk3"], app_name="App3"), + ] + setup._APPS = self.apps + + def test_install_all_apks_success(self): + self.mockdownload_and_install_apk.return_value = None + + for app in self.apps: + setup.maybe_install_app(app, self.env) + + expected_calls = [ + mock.call("apk1", self.env.controller.env), + mock.call("apk3", self.env.controller.env), + ] + self.mockdownload_and_install_apk.assert_has_calls( + expected_calls, any_order=True + ) + + def test_install_all_apks_success_with_fallback(self): + def side_effect(apk_name, env): + del env + if apk_name == "apk1": + raise errors.AdbControllerError + return None + + self.mockdownload_and_install_apk.side_effect = side_effect + + for app in self.apps: + setup.maybe_install_app(app, self.env) + + expected_calls = [ + mock.call("apk1", self.env.controller.env), + mock.call("apk2", self.env.controller.env), + mock.call("apk3", self.env.controller.env), + ] + self.mockdownload_and_install_apk.assert_has_calls(expected_calls) + + +if __name__ == "__main__": + absltest.main() diff --git a/evaluation/AndroidWorld/android_world/env/tools.py b/evaluation/AndroidWorld/android_world/env/tools.py new file mode 100644 index 0000000..17394ec --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/tools.py @@ -0,0 +1,190 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""API tools library for Android agents.""" + +import inspect +import json +import time +from typing import Optional, Union + +from android_world.env import actuation +from android_world.env import adb_utils +from android_world.env import android_world_controller +from android_world.utils import contacts_utils + + +# When the compose message is pulled up, the send button has this as text for +# Simple SMS Messenger. +SIMPLE_SMS_SEND_TEXT = "SMS" +# For Google messaging app. +SMS_SEND_TEXT = "Send SMS" + + +class AndroidToolController: + """Executes API tools on an Android device.""" + + def __init__( + self, + env: android_world_controller.AndroidWorldController, + ): + """Initializes the controller with an Android environment instance. + + Args: + env: The AndroidEnv interface to be used. + """ + self._env = env + + def click_element(self, element_text: str): + actuation.find_and_click_element(element_text, self._env) + + def open_web_page(self, url: str): + """Open a web page in the default browser on an Android device. + + This function sends an intent to the Android system to open the specified + URL. + + Args: + url: The URL of the web page to open. E.g., http://www.google.com. + """ + if not url.startswith("http://"): + url = "http://" + url + adb_command = ["shell", f"am start -a android.intent.action.VIEW -d {url}"] + adb_utils.issue_generic_request(adb_command, self._env) + + def send_sms( + self, + phone_number: str, + message: str, + ): + """Send an SMS to a specified phone number. + + This function sends an intent to the Android system to open the messaging + app with the recipient's number and message pre-filled. + + Args: + phone_number: The phone number to which the SMS should be sent. + message: The pre-filled message text. + """ + # Construct the Intent command + intent_command = ( + "am start -a android.intent.action.SENDTO -d sms:{phone_number} " + f'--es sms_body "{message}"' + ).format(phone_number=phone_number) + + adb_command = ["shell", intent_command] + adb_utils.issue_generic_request(adb_command, self._env) + time.sleep(5.0) + + package_name = adb_utils.extract_package_name( + adb_utils.get_current_activity(self._env)[0] + ) + # Depending on what the default SMS app we need to click different buttons. + if package_name == "com.google.android.apps.messaging": + self.click_element(SMS_SEND_TEXT) + elif package_name == "com.simplemobiletools.smsmessenger": + self.click_element(SIMPLE_SMS_SEND_TEXT) + else: + raise ValueError(f"Messaging app not supported: {package_name}") + + def _gather_tool_details( + self, + ) -> dict[str, list[Optional[dict[str, Union[dict[str, str], str]]]]]: + """Get the details and examples of usage for public APIs related to Android tools. + + Returns: + A dictionary where the keys are API names and the values are lists of + dictionaries containing the docstrings and usage examples. + """ + return { + "open_web_page": self._tool_info( + self.open_web_page, + [ + {"url": "http://www.google.com"}, + {"url": "http://www.example.com"}, + ], + ), + "send_sms": self._tool_info( + self.send_sms, + [ + { + "phone_number": "+123456789", + "message": "Hello, how are you?", + }, + { + "phone_number": "+987654321", + "message": "Meeting rescheduled to 3 PM.", + }, + ], + ), + "add_contact": self._tool_info( + contacts_utils.add_contact, + [ + {"name": "John Doe", "phone_number": "+123456789"}, + {"name": "Joe", "phone_number": "987654321"}, + ], + ), + } + + def _tool_info( + self, method, example_args: list[dict[str, str]] + ) -> list[Optional[dict[str, Union[dict[str, str], str]]]]: + """Helper function to construct tool information and examples. + + Args: + method: The method for which to gather information. + example_args: A list of argument dictionaries for examples. + + Returns: + A list containing the method's documentation and examples. + """ + doc_info = {"doc": inspect.getdoc(method)} + examples = [ + {"method": method.__name__, "args": args} for args in example_args + ] + return [doc_info, *examples] + + def display_tool_usage(self) -> str: + """Format the tool information and examples into a user-friendly string. + + Returns: + A string representing the available tools and their usage examples. + """ + tools_info = self._gather_tool_details() + formatted_info = ["Available Tools and Usage Examples:\n"] + + for tool_name, tool_details in tools_info.items(): + formatted_info.append(f"\nAPI: {tool_name}\n") + formatted_info.append(f"Description: {tool_details[0]['doc']}\n") + formatted_info.append("Examples:\n") + for example in tool_details[1:]: + formatted_info.append(f" - JSON Request: {example}\n") + + return "".join(formatted_info) + + def handle_json_request(self, json_request: str): + """Handle a JSON formatted request to use a tool. + + Args: + json_request: A JSON string with the method and arguments. + """ + request = json.loads(json_request) + method_name = request["method"] + args = request.get("args", {}) + + if hasattr(self, method_name) and callable(getattr(self, method_name)): + method = getattr(self, method_name) + method(**args) + else: + raise ValueError(f"Method {method_name} not found.") diff --git a/evaluation/AndroidWorld/android_world/env/tools_test.py b/evaluation/AndroidWorld/android_world/env/tools_test.py new file mode 100644 index 0000000..d449282 --- /dev/null +++ b/evaluation/AndroidWorld/android_world/env/tools_test.py @@ -0,0 +1,155 @@ +# Copyright 2024 The android_world Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from android_env import env_interface +from android_world.env import adb_utils +from android_world.env import tools + + +@mock.patch.object(adb_utils, "issue_generic_request") +class TestAndroidToolControllerOpenWebPage(absltest.TestCase): + + def test_open_web_page_with_http_prefix(self, mock_generic_request): + """Test opening a web page with http prefix.""" + mock_env = mock.create_autospec(env_interface.AndroidEnvInterface) + + controller = tools.AndroidToolController(mock_env) + controller.open_web_page("http://www.example.com") + + mock_generic_request.assert_called_once_with( + [ + "shell", + "am start -a android.intent.action.VIEW -d http://www.example.com", + ], + mock_env, + ) + + def test_open_web_page_without_http_prefix(self, mock_generic_request): + """Test opening a web page without http prefix.""" + mock_env = mock.create_autospec(env_interface.AndroidEnvInterface) + + controller = tools.AndroidToolController(mock_env) + controller.open_web_page("www.example.com") + + mock_generic_request.assert_called_once_with( + [ + "shell", + "am start -a android.intent.action.VIEW -d http://www.example.com", + ], + mock_env, + ) + + +class TestAndroidToolControllerSendSmsIntent(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="GoogleSMS", + activity_name="com.google.android.apps.messaging/ActivityName", + ui_element="Send SMS", + ), + dict( + testcase_name="SimpleSMS", + activity_name="com.simplemobiletools.smsmessenger/ActivityName", + ui_element="SMS", + ), + ) + @mock.patch.object(adb_utils, "issue_generic_request") + @mock.patch.object(tools.AndroidToolController, "click_element") + @mock.patch.object(adb_utils, "get_current_activity") + def test_send_sms( + self, + mock_get_current_activity, + mock_click_element, + mock_generic_request, + activity_name, + ui_element, + ): + mock_env = mock.create_autospec(env_interface.AndroidEnvInterface) + controller = tools.AndroidToolController(mock_env) + mock_get_current_activity.return_value = ( + activity_name, + None, + ) + phone_number = "+123456789" + message = "Hello, how are you?" + + controller.send_sms(phone_number, message) + + expected_adb_command = [ + "shell", + ( + "am start -a android.intent.action.SENDTO -d" + f' sms:{phone_number} --es sms_body "{message}"' + ), + ] + mock_generic_request.assert_called_once_with(expected_adb_command, mock_env) + mock_click_element.assert_called_once_with(ui_element) + + +@mock.patch.object(tools.AndroidToolController, "open_web_page", autospec=True) +@mock.patch.object(tools.AndroidToolController, "send_sms", autospec=True) +class TestAndroidToolControllerHandleJsonRequest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_env = mock.create_autospec(env_interface.AndroidEnvInterface) + self.controller = tools.AndroidToolController(self.mock_env) + + def test_handle_json_request_valid_method_open_web_page( + self, unused_mock_send_sms, mock_open_web + ): + json_request = json.dumps( + {"method": "open_web_page", "args": {"url": "http://www.example.com"}} + ) + + self.controller.handle_json_request(json_request) + + mock_open_web.assert_called_once_with( + self.controller, url="http://www.example.com" + ) + + def test_handle_json_request_valid_method_send_sms_intent( + self, mock_send_sms, unused_mock_open_web + ): + """Test handling a valid JSON request for sending an SMS intent.""" + json_request = json.dumps({ + "method": "send_sms", + "args": {"phone_number": "+123456789", "message": "Hello"}, + }) + + self.controller.handle_json_request(json_request) + + mock_send_sms.assert_called_once_with( + self.controller, phone_number="+123456789", message="Hello" + ) + + def test_handle_json_request_invalid_method( + self, + unused_mock_send_sms, + unused_mock_open_web, + ): + """Test handling a JSON request with an invalid method.""" + json_request = json.dumps({"method": "non_existent_method", "args": {}}) + with self.assertRaises(ValueError): + self.controller.handle_json_request(json_request) + + +if __name__ == "__main__": + absltest.main() diff --git a/evaluation/AndroidWorld/run.py b/evaluation/AndroidWorld/run.py index aa90408..e71a07a 100644 --- a/evaluation/AndroidWorld/run.py +++ b/evaluation/AndroidWorld/run.py @@ -20,8 +20,20 @@ command-line flags. """ -from collections.abc import Sequence import os + +# ---- Reduce noisy gRPC C++/absl logs (must be set BEFORE importing grpc/android_world) ---- +# These messages often look like: +# I0000 ... fork_posix.cc:71] Other threads are currently calling into gRPC, skipping fork() handlers +# They are usually harmless, but very noisy. +os.environ.setdefault("GRPC_VERBOSITY", "ERROR") +os.environ.setdefault("GRPC_TRACE", "none") +# gRPC uses Abseil logging in C++ in many builds; this env can help reduce INFO logs. +os.environ.setdefault("ABSL_MIN_LOG_LEVEL", "2") # 0=INFO,1=WARNING,2=ERROR,3=FATAL +# Some environments still honor glog. +os.environ.setdefault("GLOG_minloglevel", "2") + +from collections.abc import Sequence import concurrent.futures from absl import app from absl import flags @@ -42,9 +54,6 @@ logging.set_verbosity(logging.WARNING) -os.environ["GRPC_VERBOSITY"] = "ERROR" # Only show errors -os.environ["GRPC_TRACE"] = "none" # Disable tracing - def _find_adb_directory() -> str: """Returns the directory where adb is located.""" @@ -135,6 +144,23 @@ def _find_adb_directory() -> str: # Agent specific. _AGENT_NAME = flags.DEFINE_string("agent_name", "seeact_v", help="Agent name.") +# Qwen3VL (OpenAI-compatible server) specific. +_QWEN3VL_MODEL_BASE_URL = flags.DEFINE_string( + "qwen3vl_model_base_url", + "http://127.0.0.1:8000/v1", + "Qwen3VL OpenAI-compatible base_url, e.g. http://host:port/v1", +) +_QWEN3VL_MODEL_API_KEY = flags.DEFINE_string( + "qwen3vl_model_api_key", + "EMPTY", + "Qwen3VL API key for OpenAI-compatible server (if needed).", +) +_QWEN3VL_MODEL_NAME = flags.DEFINE_string( + "qwen3vl_model_name", + "", + "Model name passed to /v1/chat/completions (depends on your server).", +) + _FIXED_TASK_SEED = flags.DEFINE_boolean( "fixed_task_seed", False, @@ -207,6 +233,14 @@ def _get_agent( model_address=" http://10.140.66.139:10026/", mode="Agent", ) + elif _AGENT_NAME.value == "qwen3vl": + agent = seeact_v.Qwen3VL( + env, + infer.Gpt4Wrapper("gpt-4o"), + model_base_url=_QWEN3VL_MODEL_BASE_URL.value, + model_api_key=_QWEN3VL_MODEL_API_KEY.value, + model_name=_QWEN3VL_MODEL_NAME.value, + ) if not agent: raise ValueError(f"Unknown agent: {_AGENT_NAME.value}")