|
18 | 18 | import time |
19 | 19 |
|
20 | 20 | import ast |
| 21 | +import json |
21 | 22 | import numpy as np |
22 | 23 | from PIL import Image |
23 | 24 | from openai import OpenAI |
|
40 | 41 | from android_world.env import interface |
41 | 42 | from android_world.env import json_action |
42 | 43 | from android_world.env import representation_utils |
43 | | -from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize |
| 44 | + |
| 45 | +try: |
| 46 | + from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import ( # type: ignore |
| 47 | + smart_resize, |
| 48 | + ) |
| 49 | +except Exception: # pragma: no cover |
| 50 | + smart_resize = None # type: ignore[assignment] |
44 | 51 |
|
45 | 52 | # Utils for Visual Grounding |
46 | 53 |
|
@@ -932,3 +939,184 @@ def _to_base64_png(image: np.ndarray) -> str: |
932 | 939 | buf = BytesIO() |
933 | 940 | PILImage.fromarray(image).save(buf, format="PNG") |
934 | 941 | return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" |
| 942 | + |
| 943 | + |
| 944 | +def _extract_action_text_qwen3vl(block: str) -> str: |
| 945 | + """Extracts the 'Action:' line from Qwen3VL text output for step history (does not affect execution).""" |
| 946 | + m = re.search(r"Action:\s*(.+?)(?:\n<tool_call>|$)", block, flags=re.S) |
| 947 | + if not m: |
| 948 | + return "" |
| 949 | + text = m.group(1).strip() |
| 950 | + # Some models wrap Action: "..." with quotes. |
| 951 | + if text.startswith('"') and text.endswith('"'): |
| 952 | + text = text[1:-1] |
| 953 | + return text.replace("\n", " ") |
| 954 | + |
| 955 | + |
| 956 | +def _parse_tool_call_json(block: str) -> dict[str, Any] | None: |
| 957 | + """Parse JSON inside <tool_call>...</tool_call>.""" |
| 958 | + m = re.search(r"<tool_call>\s*([\s\S]*?)\s*</tool_call>", block) |
| 959 | + if not m: |
| 960 | + return None |
| 961 | + payload = m.group(1).strip() |
| 962 | + try: |
| 963 | + return json.loads(payload) |
| 964 | + except Exception: |
| 965 | + return None |
| 966 | + |
| 967 | + |
class Qwen3VL(base_agent.EnvironmentInteractingAgent):
    """Android GUI Agent based on Qwen3VL tool-call output (for AndroidWorld eval).

    - Input: Screenshot + instruction + history
    - Output: <tool_call>{...}</tool_call>
    - Execution: Map to JSONAction by qwen3vl_action_transform(...)
    """

    def __init__(
        self,
        env: interface.AsyncEnv,
        llm: infer.MultimodalLlmWrapper,
        name: str = "Qwen3VL",
        wait_after_action_seconds: float = 2.0,
        model_base_url: str = "http://127.0.0.1:8000/v1",
        model_api_key: str = "EMPTY",
        model_name: str = "",
        extra_headers: dict[str, str] | None = None,
    ):
        """Initializes the agent and its OpenAI-compatible client.

        Args:
            env: AndroidWorld environment the agent acts on.
            llm: Multimodal LLM wrapper. NOTE(review): stored but never used in
                this class — model calls go through ``self.client`` instead;
                presumably kept for interface parity with sibling agents.
            name: Agent name passed to the base class.
            wait_after_action_seconds: Delay after each executed action so the
                UI can settle before the next screenshot.
            model_base_url: Base URL of the OpenAI-compatible endpoint
                (default points at a local server on port 8000).
            model_api_key: API key for the endpoint ("EMPTY" for local servers).
            model_name: Model identifier sent with each chat completion request.
            extra_headers: Optional default HTTP headers for every request.
        """
        super().__init__(env, name)
        self.llm = llm
        self.wait_after_action_seconds = wait_after_action_seconds
        self.model_name = model_name
        self.client = OpenAI(
            api_key=model_api_key,
            base_url=model_base_url,
            default_headers=extra_headers,
        )
        # Accumulated "Step N: <action>; " text injected into every prompt.
        self.step_his: str = ""
        # 1-based step counter; incremented at the top of step().
        self.turn_number: int = 0
        # Used to detect repeated actions (avoid infinite loops)
        self.last_action: str | None = None
        self.repeat_time: int = 0

    def reset(self, go_home_on_reset: bool = False) -> None:
        """Resets per-episode state: history, step counter, repeat tracking."""
        super().reset(go_home_on_reset)
        self.env.hide_automation_ui()
        self.step_his = ""
        self.turn_number = 0
        self.last_action = None
        self.repeat_time = 0

    @staticmethod
    def _to_base64_png(image: np.ndarray) -> str:
        """Encodes an image array as a base64 PNG data URL for the API payload."""
        # Local imports keep this helper self-contained (mirrors the
        # module-level _to_base64_png helper elsewhere in this file).
        import base64
        from io import BytesIO
        from PIL import Image as PILImage
        buf = BytesIO()
        PILImage.fromarray(image).save(buf, format='PNG')
        return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"

    def step(self, instruction: str) -> base_agent.AgentInteractionResult:
        """Runs one agent step: screenshot -> prompt -> tool call -> action.

        Args:
            instruction: Natural-language task instruction for this episode.

        Returns:
            AgentInteractionResult whose first field is True when the episode
            should stop (no parsable tool call, transform failure, a "status"
            action, or 3+ repeated identical actions) and False otherwise.
        """
        self.turn_number += 1

        state = self.get_post_transition_state()
        screenshot = state.pixels.copy()
        # To be consistent with other agents in this file: BGR->RGB (for saving/encoding)
        screenshot = screenshot[:, :, ::-1]
        height, width = screenshot.shape[:2]

        system_prompt = QWEN3VL_SYSTEM_PROMPT
        user_prompt = QWEN3VL_USER_PROMPT.format(
            instruction=instruction, history=self.step_his
        )

        # Single-turn request: system prompt + (instruction/history text, screenshot).
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt},
                    {"type": "image_url", "image_url": {"url": self._to_base64_png(screenshot)}},
                ],
            },
        ]

        # temperature=0 for deterministic decoding during evaluation.
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=0,
        )
        response = completion.choices[0].message.content or ""
        print(response)
        print("=" * 50)

        tool_call = _parse_tool_call_json(response)
        if not tool_call:
            # Nothing executable in the output: end the episode immediately.
            return base_agent.AgentInteractionResult(
                True, {"summary": "No <tool_call> JSON found in model output.", "response": response}
            )

        # Record the free-text "Action:" line in the history (display only).
        op_text = _extract_action_text_qwen3vl(response)
        if op_text:
            self.step_his += f"Step {self.turn_number}: {op_text}; "

        # Compatible: tool_call may look like {"name":"mobile_use","arguments":{...}}
        # NOTE(review): if "arguments" is present but not a dict, args.get below
        # would raise AttributeError — confirm upstream format guarantees this.
        args = tool_call.get("arguments", {}) if isinstance(tool_call, dict) else {}
        action_name = args.get("action", "")
        try:
            parsed = qwen3vl_action_transform(action_name, args, width, height)
            print(parsed)
        except Exception as e:
            # Unmappable tool call: surface the error and stop the episode.
            return base_agent.AgentInteractionResult(
                True,
                {
                    "summary": f"Failed to transform tool-call into action: {e}",
                    "response": response,
                    "tool_call": tool_call,
                },
            )

        # Record last_action + repeat_time (previous code had these fields but not working)
        # Here, use the tool-call's arguments as the "action signature", which is more robust than checking 'terminate' in a string.
        try:
            action_sig = json.dumps(args, ensure_ascii=False, sort_keys=True)
        except Exception:
            # Non-JSON-serializable arguments: fall back to repr-style signature.
            action_sig = str(args)
        if self.last_action == action_sig:
            self.repeat_time += 1
        else:
            self.repeat_time = 0
        self.last_action = action_sig

        # Best-effort execution: failures are logged and the episode continues.
        try:
            act = json_action.JSONAction(**parsed)
            self.env.execute_action(act)
            time.sleep(self.wait_after_action_seconds)
        except Exception:
            # continue
            print("Failed to execute action:", parsed)

        if parsed.get("action_type") == "status":
            # Model declared task completion/failure: stop the episode.
            return base_agent.AgentInteractionResult(
                True, {"response": response, "step_history": self.step_his, "parsed": parsed}
            )

        # If repeated actions reach the threshold: terminate immediately to avoid deadlock in evaluation
        if self.repeat_time >= 3:
            return base_agent.AgentInteractionResult(
                True,
                {
                    "summary": "Terminated due to repeated identical actions.",
                    "response": response,
                    "step_history": self.step_his,
                    "parsed": parsed,
                    "repeat_time": self.repeat_time,
                },
            )

        return base_agent.AgentInteractionResult(
            False, {"response": response, "step_history": self.step_his, "parsed": parsed}
        )
0 commit comments