Hi~
I tried evaluating AgentCPM-GUI on two GUI grounding benchmarks, ScreenSpot-v2 and ScreenSpot, using the prompt from fun_2_bbox.py. The scores I get are quite low: only about 59% accuracy on ScreenSpot-v2, whereas models such as Qwen already reach around 90%. I am not sure whether my inference code is at fault. Has the official team evaluated on these two benchmarks?
My inference code is:
import os
import re
import io
import json
import base64
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class AgentCPM_GUI():
    def load_model(self, model_name_or_path="model/AgentCPM-GUI"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).to("cuda").eval()
        self.generation_params = {'do_sample': False, 'temperature': 0.0, 'use_cache': True, 'max_new_tokens': 2048}
        # --- system prompt ---
        self.sys_prompt = '''
你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的下一步操作是根据给定的GUI截图和图中某个组件的功能描述点击组件的中心位置。坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到0~1000
输入:屏幕截图,功能描述
输出:点击操作,以{\"POINT\":[...,...]}为格式,其中不能存在任何非坐标字符
# Rule
- 输出操作必须遵循Schema约束
# Schema
{
"required": ["thought"]
}
'''
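        # With this prompt the model is expected to reply with something like
        # {"POINT": [512, 340]} in the 0-1000 coordinate space described above;
        # the concrete numbers here are only an illustrative assumption.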
    def _resize(self, origin_img):
        resolution = origin_img.size
        w, h = resolution
        max_line_res = 1120
        if max_line_res is not None:
            max_line = max_line_res
            if h > max_line:
                w = int(w * max_line / h)
                h = max_line
            if w > max_line:
                h = int(h * max_line / w)
                w = max_line
        img = origin_img.resize((w, h), resample=Image.Resampling.LANCZOS)
        return img
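    # Quick sanity check of the resize above (the 1080x2400 input size is an assumed
    # example, not taken from the benchmark): h=2400 > 1120, so w becomes
    # int(1080 * 1120 / 2400) = 504 and h becomes 1120; the second branch then
    # leaves the 504x1120 result unchanged.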
    def set_generation_config(self, **kwargs):
        self.generation_params.update(**kwargs)
    def inference(self, instruction, image_path):
        assert os.path.exists(image_path) and os.path.isfile(image_path), "Invalid input image path."
        image = Image.open(image_path)
        image = self._resize(image)
        messages = [{
            "role": "user",
            "content": [
                f"屏幕上某一组件的功能描述:{instruction}\n当前屏幕截图:",
                image
            ]
        }]
        print("instruction:", instruction)
        output_text = self.model.chat(
            image=None,
            system_prompt=self.sys_prompt,
            msgs=messages,
            tokenizer=self.tokenizer,
            temperature=0.1,
            # do_sample=False
        )
        print("Raw response:", output_text)
        # Parse the predicted point; fall back to (0, 0) and warn if no POINT is found.
        x_rel = y_rel = 0.0
        match = re.search(r'"POINT"\s*:\s*\[\s*(\d+(?:\.\d+)?),\s*(\d+(?:\.\d+)?)\s*\]', output_text)
        if match:
            x_rel = float(match.group(1))
            y_rel = float(match.group(2))
        else:
            print("Warning: No POINT parsed.")
        # Map coordinates from the 0-1000 prompt space back to [0, 1].
        x_norm = x_rel / 1000.0
        y_norm = y_rel / 1000.0
        result = {
            "result": "positive",
            "format": "x1y1x2y2",
            "raw_response": output_text,
            "bbox": None,
            "point": [x_norm, y_norm],
        }
        return result
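For reference, this is roughly how I drive the class on a single ScreenSpot-style sample and score it; the field names, the normalized x1y1x2y2 bbox format, and the point-in-bbox check below are my own assumptions rather than the official evaluation script:
if __name__ == "__main__":
    agent = AgentCPM_GUI()
    agent.load_model("model/AgentCPM-GUI")

    # One ScreenSpot-style sample (paths and bbox values are placeholders).
    sample = {
        "instruction": "click the search button",
        "image_path": "screenspot_imgs/example.png",
        "bbox": [0.42, 0.10, 0.58, 0.16],  # normalized [x1, y1, x2, y2] (assumed format)
    }

    pred = agent.inference(sample["instruction"], sample["image_path"])
    x, y = pred["point"]

    # Grounding accuracy counts a sample as correct when the predicted point
    # falls inside the ground-truth box (my reimplementation of the metric).
    x1, y1, x2, y2 = sample["bbox"]
    hit = (x1 <= x <= x2) and (y1 <= y <= y2)
    print("prediction:", (x, y), "hit:", hit)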