Skip to content

Commit cd1d194

Browse files
authored
[Task] V*-Bench (Visual Star Benchmark) (#683)
* Update project configuration and scripts - Updated Python version requirement in pyproject.toml to >=3.12. - Removed specific version constraint for protobuf in dependencies. - Added 'uv.lock' to .gitignore. - Modified example script to change model task from 'mmmu_pro' to 'mme' and updated comments for clarity. * Update example script comments for clarity on visual token positioning * Add .venv to .gitignore to exclude virtual environment files * Refactor Qwen2_5_VL model initialization and add V* benchmark task files - Changed `use_flash_attention_2` parameter to `attn_implementation` for better flexibility in attention methods. - Added validation for `attn_implementation` to ensure only valid options are accepted. - Updated model loading to dynamically include attention implementation in arguments. - Introduced new V* benchmark task files, including default template, utility functions, and specific task configurations for direct attributes and relative position metrics. * Add VLMs Are Blind benchmark task files - Introduced the VLMs Are Blind benchmark to evaluate visual reasoning capabilities of Vision-Language Models through path-counting tasks in subway diagrams. - Added task configuration files, utility functions, and a comprehensive README for task instructions and dataset details. - Implemented both standard and lite versions of the benchmark for varied evaluation speeds.
1 parent 026aa05 commit cd1d194

14 files changed

+515
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ lmms_eval/tasks/mlvu/__pycache__/utils.cpython-310.pyc
4444

4545
scripts/
4646
.env
47+
.venv
4748
outputs/
4849
span.log
4950
uv.lock

lmms_eval/models/qwen2_5_vl.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
device_map: Optional[str] = "auto",
4343
batch_size: Optional[Union[int, str]] = 1,
4444
use_cache=True,
45-
use_flash_attention_2: Optional[bool] = False,
45+
attn_implementation: Optional[str] = None,
4646
min_pixels: int = 256 * 28 * 28,
4747
max_pixels: int = 1605632,
4848
max_num_frames: int = 32,
@@ -58,6 +58,11 @@ def __init__(
5858
# Do not use kwargs for now
5959
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
6060

61+
# Validate attention implementation
62+
valid_attn_implementations = [None, "flash_attention_2", "sdpa", "eager"]
63+
if attn_implementation not in valid_attn_implementations:
64+
raise ValueError(f"attn_implementation must be one of {valid_attn_implementations}, got {attn_implementation}")
65+
6166
self.use_custom_video_loader = use_custom_video_loader
6267
self.fps = fps
6368
# if self.fps and not self.use_custom_video_loader:
@@ -74,15 +79,17 @@ def __init__(
7479
self._device = torch.device(device)
7580
self.device_map = device_map if device_map else device
7681

77-
if use_flash_attention_2:
78-
self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
79-
pretrained,
80-
torch_dtype=torch.bfloat16,
81-
device_map=self.device_map,
82-
attn_implementation="flash_attention_2",
83-
).eval()
84-
else:
85-
self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained(pretrained, torch_dtype="auto", device_map=self.device_map).eval()
82+
# Prepare model loading arguments
83+
model_kwargs = {
84+
"torch_dtype": "auto",
85+
"device_map": self.device_map,
86+
}
87+
88+
# Add attention implementation if specified
89+
if attn_implementation is not None:
90+
model_kwargs["attn_implementation"] = attn_implementation
91+
92+
self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained(pretrained, **model_kwargs).eval()
8693
self.max_pixels = max_pixels
8794
self.min_pixels = min_pixels
8895
self.max_num_frames = max_num_frames
@@ -296,14 +303,20 @@ def _collate(x):
296303
}
297304
# Update with provided kwargs
298305
current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}
299-
300306
pad_token_id = self.tokenizer.pad_token_id
301307

308+
if current_gen_kwargs["temperature"] > 0:
309+
current_gen_kwargs["do_sample"] = True
310+
else:
311+
current_gen_kwargs["do_sample"] = False
312+
current_gen_kwargs["temperature"] = None
313+
current_gen_kwargs["top_p"] = None
314+
302315
cont = self.model.generate(
303316
**inputs,
304317
eos_token_id=self.tokenizer.eos_token_id,
305318
pad_token_id=pad_token_id,
306-
do_sample=True if current_gen_kwargs["temperature"] > 0 else False,
319+
do_sample=current_gen_kwargs["do_sample"],
307320
temperature=current_gen_kwargs["temperature"],
308321
top_p=current_gen_kwargs["top_p"],
309322
num_beams=current_gen_kwargs["num_beams"],
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# VLMs Are Blind
2+
3+
## Overview
4+
5+
VLMs Are Blind is a benchmark designed to test the visual reasoning capabilities of Vision-Language Models (VLMs) through path-counting tasks in subway connection diagrams. The benchmark reveals fundamental limitations in VLMs' ability to process visual information, showing that many models struggle with basic visual tasks that are trivial for humans.
6+
7+
## Paper Information
8+
9+
- **Paper Title**: VLMs Are Blind
10+
- **Paper**: https://arxiv.org/abs/2407.06581
11+
- **GitHub**: https://github.com/xai-org/vlmsareblind
12+
- **Dataset**: https://huggingface.co/datasets/XAI/vlmsareblind
13+
14+
## Dataset Details
15+
16+
The benchmark consists of path-counting tasks where models must count the number of paths between two stations in subway-style connection diagrams. Each instance contains:
17+
- A subway map diagram image
18+
- A question asking for the number of paths between two specified stations
19+
- The correct answer in the format {N}
20+
21+
## Task Configuration
22+
23+
### Main Task
24+
- **Task Name**: `vlmsareblind`
25+
- **Split**: `valid`
26+
- **Output Type**: `generate_until`
27+
- **Metric**: Exact match accuracy
28+
29+
### Lite Version
30+
- **Task Name**: `vlmsareblind_lite`
31+
- A subset version for faster evaluation
32+
33+
## Evaluation
34+
35+
The benchmark uses exact match accuracy as the primary metric. Models must output their answer in the format `{N}` where N is the number of paths.
36+
37+
### Metrics
38+
- **exact_match**: Binary score for each instance (1 if prediction matches ground truth, 0 otherwise)
39+
- **Aggregation**: Mean across all instances
40+
41+
## Running the Benchmark
42+
43+
```bash
44+
# Run the full benchmark
45+
lmms-eval --model <model_name> --tasks vlmsareblind --batch_size 1
46+
47+
# Run the lite version
48+
lmms-eval --model <model_name> --tasks vlmsareblind_lite --batch_size 1
49+
```
50+
51+
## Implementation Details
52+
53+
### Generation Configuration
54+
- **max_new_tokens**: 32
55+
- **temperature**: 0
56+
- **top_p**: 1.0
57+
- **num_beams**: 1
58+
- **do_sample**: false
59+
60+
### Answer Extraction
61+
The evaluation expects answers in the format `{N}`. The answer extraction logic:
62+
1. First looks for numbers within curly brackets: `{3}`
63+
2. If not found, looks for standalone numbers and adds brackets
64+
3. Falls back to the raw response if no pattern matches
65+
66+
### Prompt Format
67+
```
68+
[Image of subway diagram]
69+
[Question about counting paths]
70+
Answer with a number in curly brackets, e.g., {3}.
71+
```
72+
73+
## File Structure
74+
```
75+
vlmsareblind/
76+
├── README.md # This file
77+
├── utils.py # Evaluation utilities
78+
├── vlmsareblind.yaml # Main task configuration
79+
└── vlmsareblind_lite.yaml # Lite version configuration
80+
```
81+
82+
## Citation
83+
84+
```bibtex
85+
@article{vlmsareblind2024,
86+
title={VLMs Are Blind},
87+
author={XAI Team},
88+
journal={arXiv preprint arXiv:2407.06581},
89+
year={2024}
90+
}
91+
```
92+
93+
## Notes
94+
95+
- The benchmark is designed to test fundamental visual reasoning capabilities
96+
- Results often reveal significant gaps between VLM performance and human abilities
97+
- The simple counting task format makes it easy to verify model outputs
98+
- Temperature is set to 0 for deterministic outputs
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# VLMs Are Blind benchmark task
2+
# Tests visual reasoning capabilities through path-counting in subway connection diagrams

lmms_eval/tasks/vlmsareblind/utils.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import re
2+
from typing import Dict, List
3+
4+
5+
def vlmsareblind_doc_to_visual(doc: Dict) -> List:
6+
"""Extract image from the document."""
7+
if "image" in doc:
8+
return [doc["image"]]
9+
return []
10+
11+
12+
def vlmsareblind_doc_to_text(doc: Dict, lmms_eval_specific_kwargs: Dict) -> str:
13+
"""Format the prompt for the model."""
14+
prompt = doc.get("prompt", "")
15+
pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
16+
post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "")
17+
return f"{pre_prompt}{prompt}{post_prompt}"
18+
19+
20+
def vlmsareblind_doc_to_target(doc: Dict) -> str:
21+
"""Extract the expected answer from the document."""
22+
# The answer should be a number in curly brackets like {3}
23+
answer = doc.get("answer", "")
24+
return str(answer)
25+
26+
27+
def extract_answer(response: str) -> str:
28+
"""Extract the number from a response containing {N} format."""
29+
# Look for pattern {number}
30+
match = re.search(r"\{(\d+)\}", response)
31+
if match:
32+
return match.group(0) # Return the full match including brackets
33+
34+
# If no brackets found, try to find just a number
35+
match = re.search(r"\b(\d+)\b", response)
36+
if match:
37+
return f"{{{match.group(1)}}}" # Add brackets
38+
39+
return response.strip()
40+
41+
42+
def vlmsareblind_process_result(doc: Dict, result) -> Dict:
43+
"""Process the model's response and compare with ground truth."""
44+
# Handle case where result is a list (common with generate_until)
45+
if isinstance(result, list):
46+
result = result[0] if result else ""
47+
48+
pred = extract_answer(str(result))
49+
target = doc.get("answer", "")
50+
51+
# Exact match comparison
52+
correct = pred == target
53+
54+
return {
55+
"exact_match": correct,
56+
"pred": pred,
57+
"target": target,
58+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
dataset_path: XAI/vlmsareblind
2+
task: "vlmsareblind"
3+
test_split: valid
4+
output_type: generate_until
5+
doc_to_visual: !function utils.vlmsareblind_doc_to_visual
6+
doc_to_text: !function utils.vlmsareblind_doc_to_text
7+
doc_to_target: !function utils.vlmsareblind_doc_to_target
8+
process_results: !function utils.vlmsareblind_process_result
9+
generation_kwargs:
10+
max_new_tokens: 32
11+
temperature: 0
12+
top_p: 1.0
13+
num_beams: 1
14+
do_sample: false
15+
metric_list:
16+
- metric: exact_match
17+
aggregation: mean
18+
higher_is_better: true
19+
metadata:
20+
- version: 0.0
21+
- description: "VLMs Are Blind: A benchmark testing visual reasoning capabilities of VLMs through path-counting tasks in subway connection diagrams."
22+
- reference: "https://arxiv.org/abs/2407.06581"
23+
24+
lmms_eval_specific_kwargs:
25+
default:
26+
pre_prompt: ""
27+
post_prompt: "\nAnswer with a number in curly brackets, e.g., {3}."
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
dataset_path: XAI/vlmsareblind
2+
task: "vlmsareblind_lite"
3+
test_split: valid
4+
output_type: generate_until
5+
doc_to_visual: !function utils.vlmsareblind_doc_to_visual
6+
doc_to_text: !function utils.vlmsareblind_doc_to_text
7+
doc_to_target: !function utils.vlmsareblind_doc_to_target
8+
process_results: !function utils.vlmsareblind_process_result
9+
generation_kwargs:
10+
max_new_tokens: 32
11+
temperature: 0
12+
top_p: 1.0
13+
num_beams: 1
14+
do_sample: false
15+
# Sample only 100 examples for lite version
16+
dataset_kwargs:
17+
num_examples: 100
18+
metric_list:
19+
- metric: exact_match
20+
aggregation: mean
21+
higher_is_better: true
22+
metadata:
23+
- version: 0.0
24+
- description: "VLMs Are Blind (Lite): A smaller subset for quick testing. Tests visual reasoning through path-counting in subway diagrams."
25+
- reference: "https://arxiv.org/abs/2407.06581"
26+
27+
lmms_eval_specific_kwargs:
28+
default:
29+
pre_prompt: ""
30+
post_prompt: "\nAnswer with a number in curly brackets, e.g., {3}."

lmms_eval/tasks/vstar_bench/README.md

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# V*-Bench (Visual Star Benchmark)
2+
3+
## Overview
4+
5+
V*-Bench is a visual question-answering benchmark designed to evaluate multimodal language models' capabilities in visual perception and reasoning. The benchmark focuses on assessing models' ability to accurately identify and reason about visual attributes in images through multiple-choice questions.
6+
7+
## Dataset Details
8+
9+
- **Dataset**: `lmms-lab/vstar-bench`
10+
- **Size**: 191 test samples
11+
- **Format**: Multiple-choice questions with 4 options (A, B, C, D)
12+
- **Modalities**: Image + Text
13+
14+
## Task Categories
15+
16+
The benchmark includes two main categories:
17+
18+
1. **Direct Attributes** (`vstar_bench_direct_attributes`)
19+
- Questions about direct visual properties such as colors, objects, counts, and characteristics
20+
- Examples: "What is the color of the glove?", "What is the breed of the dog?", "How many people are in the image?"
21+
22+
2. **Relative Position** (`vstar_bench_relative_position`)
23+
- Questions about spatial relationships and positioning of objects within images
24+
- Evaluates understanding of spatial concepts and object relationships
25+
26+
## Evaluation
27+
28+
### Metrics
29+
- **Overall Accuracy**: Percentage of correctly answered questions across all categories
30+
- **Category-specific Accuracy**: Accuracy for each individual category (direct_attributes, relative_position)
31+
32+
### Running the Benchmark
33+
34+
To evaluate a model on V*-Bench:
35+
36+
```bash
37+
# Run the full benchmark
38+
lmms-eval --model <model_name> --tasks vstar_bench --output_path ./results
39+
40+
# Run specific categories
41+
lmms-eval --model <model_name> --tasks vstar_bench_direct_attributes --output_path ./results
42+
lmms-eval --model <model_name> --tasks vstar_bench_relative_position --output_path ./results
43+
```
44+
45+
## Configuration
46+
47+
The benchmark uses the following configuration:
48+
- **Generation Settings**:
49+
- `max_new_tokens`: 16
50+
- `temperature`: 0
51+
- `top_p`: 1.0
52+
- `num_beams`: 1
53+
- `do_sample`: false
54+
55+
- **Prompt Template**:
56+
- Post-prompt: "\nAnswer with the option's letter from the given choices directly."
57+
58+
## Implementation Details
59+
60+
### Answer Extraction
61+
The evaluation system extracts answer letters (A, B, C, or D) from model responses using multiple patterns to handle various response formats:
62+
- Direct letter: "A"
63+
- With punctuation: "A.", "A)", "(A)"
64+
- Full answer format: "Answer: A", "The answer is A"
65+
66+
### Aggregation
67+
Results are aggregated both by category and overall, providing detailed performance metrics for different aspects of visual understanding.
68+
69+
## File Structure
70+
71+
```
72+
vstar_bench/
73+
├── __init__.py
74+
├── README.md
75+
├── _default_template_yaml # Base configuration
76+
├── vstar_bench.yaml # Main task configuration
77+
├── vstar_bench_direct_attributes.yaml
78+
├── vstar_bench_relative_position.yaml
79+
└── utils.py # Processing and evaluation functions
80+
```
81+
82+
## References
83+
84+
- Dataset: https://huggingface.co/datasets/lmms-lab/vstar-bench
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# V* Benchmark: Guided Visual Search as a Core Mechanism in Multimodal LLMs

0 commit comments

Comments
 (0)