Skip to content

screenspot-v2评测分数很低 #98

@gaolongxi

Description

@gaolongxi

Hi~
我尝试在两个GUI Grounding的benchmark:ScreenSpot-v2和ScreenSpot上面测试AgentCPM-GUI的表现,使用的是fun_2_bbox.py中的prompt,目前测出来分数很低,screenspot-v2只有59%左右正确率,qwen等模型已经在90%左右。我不知道是否是我的推理代码有问题,请问官方在这两个benchmark上面测试过吗?

我的推理代码是:

import os
import re
import io
import json
import base64
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class AgentCPM_GUI:
    """Inference wrapper around AgentCPM-GUI for GUI-grounding benchmarks.

    Call ``load_model()`` once, then ``inference(instruction, image_path)``
    per sample. ``inference`` returns the predicted click point as the
    model's 0-1000 coordinates divided by 1000.
    """

    # Pulls the two numbers out of a {"POINT": [x, y]} answer. Whitespace is
    # tolerated around both numbers and around the comma.
    _POINT_PATTERN = re.compile(
        r'"POINT"\s*:\s*\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]'
    )

    def load_model(self, model_name_or_path="model/AgentCPM-GUI"):
        """Load the tokenizer and model onto CUDA and set prompt/generation config.

        Args:
            model_name_or_path: Local directory or hub id passed to
                ``from_pretrained`` (with ``trust_remote_code=True``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        # NOTE(review): ``dtype=`` is the newer transformers spelling; older
        # releases expect ``torch_dtype=``. Confirm against the installed
        # version, otherwise the weights may not load in bfloat16.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).to("cuda").eval()
        # NOTE(review): stored and updatable via set_generation_config(), but
        # inference() does not currently forward these to model.chat.
        self.generation_params = {'do_sample': False, 'temperature': 0.0, 'use_cache': True, 'max_new_tokens': 2048}

        # System prompt, kept verbatim because it is model input. It asks for
        # the click point of the described component as {"POINT": [x, y]},
        # scaled to 0~1000 relative to the top-left corner.
        # NOTE(review): the "# Schema" section requires a "thought" field,
        # while the output line forbids any non-coordinate characters. Confirm
        # this matches the prompt the model expects (e.g. fun_2_bbox.py).
        self.sys_prompt = '''
你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的下一步操作是根据给定的GUI截图和图中某个组件的功能描述点击组件的中心位置。坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到0~1000
输入:屏幕截图,功能描述
输出:点击操作,以{\"POINT\":[...,...]}为格式,其中不能存在任何非坐标字符

# Rule
- 输出操作必须遵循Schema约束

# Schema
{
    "required": ["thought"]
}
'''

    def _resize(self, origin_img, max_line_res=1120):
        """Downscale so that neither side exceeds ``max_line_res`` pixels.

        The aspect ratio is preserved up to integer truncation. Images already
        within the limit are re-sampled at their original size. Pass
        ``max_line_res=None`` to skip the limit entirely.

        Returns:
            A new PIL image produced with LANCZOS resampling.
        """
        w, h = origin_img.size
        if max_line_res is not None:
            if h > max_line_res:
                # max(1, ...) avoids a zero-width result for extremely tall,
                # narrow images.
                w = max(1, int(w * max_line_res / h))
                h = max_line_res
            if w > max_line_res:
                h = max(1, int(h * max_line_res / w))
                w = max_line_res
        return origin_img.resize((w, h), resample=Image.Resampling.LANCZOS)

    def set_generation_config(self, **kwargs):
        """Merge ``kwargs`` into ``self.generation_params`` (requires load_model first)."""
        self.generation_params.update(**kwargs)

    def _load_image(self, image_path):
        """Open ``image_path`` as a 3-channel RGB image and close the file.

        ``convert("RGB")`` returns a new, fully loaded image, so the result
        stays valid after the ``with`` block closes the file handle. It also
        normalises RGBA/palette/greyscale screenshots to RGB.
        """
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def _parse_point(self, output_text):
        """Return ``(x, y)`` as floats from a ``{"POINT": [x, y]}`` answer.

        If no point is found, prints a warning and returns ``(0.0, 0.0)``
        rather than raising.
        """
        match = self._POINT_PATTERN.search(str(output_text))
        if match is None:
            print("Warning: No POINT parsed.")
            return 0.0, 0.0
        return float(match.group(1)), float(match.group(2))

    def inference(self, instruction, image_path):
        """Predict the click point for ``instruction`` on the screenshot at ``image_path``.

        Args:
            instruction: Natural-language description of the target component.
            image_path: Path to the screenshot file.

        Returns:
            dict with ``point`` = [x, y] (the model's 0-1000 coordinates
            divided by 1000, not clamped), ``raw_response``, and the fixed
            ``result``/``format``/``bbox`` fields.

        Raises:
            AssertionError: if ``image_path`` is not an existing regular file.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``;
        # the exception type is kept as AssertionError for existing callers.
        if not os.path.isfile(image_path):
            raise AssertionError("Invalid input image path.")

        image = self._resize(self._load_image(image_path))

        messages = [{
            "role": "user",
            "content": [
                f"屏幕上某一组件的功能描述:{instruction}\n当前屏幕截图:",
                image
            ]
        }]
        print("instruction:", instruction)

        # NOTE(review): self.generation_params is not passed here; only
        # temperature=0.1 is set explicitly. Whether decoding is greedy or
        # sampled therefore depends on model.chat's own defaults. Confirm
        # this, since a sampled decode makes benchmark scores noisy.
        output_text = self.model.chat(
            image=None,
            system_prompt=self.sys_prompt,
            msgs=messages,
            tokenizer=self.tokenizer,
            temperature=0.1,
        )
        print("Raw response:", output_text)

        x_rel, y_rel = self._parse_point(output_text)

        return {
            "result": "positive",
            "format": "x1y1x2y2",
            "raw_response": output_text,
            "bbox": None,
            "point": [x_rel / 1000.0, y_rel / 1000.0],
        }

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions