
Commit 5aae37a

feat(evaluation): introduce new VLM metrics and integration tests
- Added new metrics: AlignmentScoreMetric, ImageEditScoreMetric, QAAccuracyMetric, TextScoreMetric, VieScoreMetric, and VQAMetric for comprehensive evaluation of image-text alignment and quality.
- Implemented an integration test script for VLM metrics, allowing testing against both Litellm and Transformers backends.
- Updated pyproject.toml to reflect new dependencies and changes in optional dependencies.
- Added documentation for prompt comparisons between the Pruna and InferBench implementations.
1 parent d3d659b commit 5aae37a

File tree

13 files changed: +1349 −761 lines

Lines changed: 158 additions & 0 deletions
# VLM Metrics: Prompt Comparison (Pruna vs InferBench)

An overview of the prompt differences between Pruna's VLM metrics and InferBench's implementation.

---

## Summary Table

| Metric | Pruna | InferBench | Key Differences |
|--------|-------|------------|-----------------|
| **Alignment Score** | Single generic question | Multi-question with dependencies | Pruna: 1 prompt; InferBench: N questions from OneIG JSON |
| **VQA** | Same as Alignment (reused) | Dedicated template | Both use "Does this show X? Yes/No" |
| **Text Score** | Short OCR prompt | Detailed OCR prompt | InferBench: longer, explicit format rules |
| **Img Edit Score** | Simple 0–10 rating | Full judge prompts from ImgEdit repo | InferBench: 5-point multi-criteria per edit type |
| **VieScore** | Two short prompts | Long SC + PQ prompts | InferBench: detailed rules, JSON output |
| **QA Accuracy** | Generic "What is in this image?" | Benchmark-specific questions | Different use cases |
| **VLM Base (score)** | Litellm: "Answer Yes or No" / Transformers: "Question: X Answer:" | Generation + logprobs fallback | Response format differs |

---

## 1. Alignment Score

### Pruna
- **Question**: `Does this image show "{prompt}"? Answer Yes or No.`
- **Expected answer**: `Yes`
- **Scope**: Single prompt–image alignment per sample
- **Source**: `metric_alignment_score.py`, `metric_vqa.py` (same logic)

### InferBench
- **Questions**: From OneIG JSON (e.g. `anime.json`, `human.json`, `object.json`)
- **Template**: `{question}. Only answer 'Yes' or 'No'. Do not answer anything else.`
- **Examples**: "Are there boys?", "Are there four boys?", "Is there a nun?", etc.
- **Dependencies**: Parent–child question graph; child scores are set to 0 if the parent answer is No
- **Scope**: 9–20 questions per image, dependency-aware aggregation
- **Source**: `alignment_score.py`, `oneig.py` (benchmark)
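
The dependency-aware aggregation can be sketched as follows (a minimal sketch; the function name and the parent-index representation are assumptions, not the actual `oneig.py` code): a question's score is forced to 0 when any ancestor question was answered No.

```python
def aggregate_with_dependencies(scores, parents):
    """Aggregate per-question Yes/No scores with parent dependencies.

    scores  : per-question scores (1.0 = Yes, 0.0 = No)
    parents : parents[i] is the index of question i's parent, or None
    """
    final = list(scores)
    for i in range(len(scores)):
        # Walk up the ancestor chain; a "No" anywhere zeroes this question.
        j = parents[i]
        while j is not None:
            if scores[j] == 0.0:
                final[i] = 0.0
                break
            j = parents[j]
    return sum(final) / len(final)
```

With questions like "Are there boys?" (parent) and "Are there four boys?" (child), a No on the parent zeroes the child regardless of the child's own answer.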

---

## 2. VQA (Visual Question Answering)

### Pruna
- Same as Alignment Score: `Does this image show "{prompt}"? Answer Yes or No.`
- Used for both `alignment_score` and `vqa` metrics

### InferBench
- **Template**: `Does this figure show "{prompt}"? Please answer yes or no.`
- **Expected answer**: `Yes`
- **Difference**: "figure" vs "image"; "Please answer yes or no" vs "Answer Yes or No"
- **Source**: `vqa.py`

---

## 3. Text Score (OCR)

### Pruna
- **Prompt**: `Extract all text from this image. If no text, say 'No text'.`
- **Output use**: Binary check (no text → score 10.0, else 0.0) — *Note: Pruna's text_score appears to use edit-distance logic elsewhere; this prompt is for OCR extraction*
- **Source**: `metric_text_score.py`

### InferBench
- **Prompt**:

  ```
  Extract all text visible in this image. Include logos, stylized fonts, handwritten text, and non-standard typography.
  Return only the extracted text, exactly as it appears—no preamble, explanation, or markdown.
  Preserve words, numbers, punctuation, and spacing. If no text is recognized, reply with exactly: No text recognized
  ```

- **Post-processing**: Hallucination removal ("addCriterion", "No text recognized"), Levenshtein distance vs ground truth, word accuracy
- **Source**: `text_score.py`
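
The edit-distance comparison against the ground truth can be sketched like this (hypothetical helper names; the real `text_score.py` may normalize case, whitespace, or hallucinated tokens differently before comparing):

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]


def ocr_scores(extracted: str, ground_truth: str):
    """Character-level similarity plus word-level accuracy."""
    dist = levenshtein(extracted, ground_truth)
    char_sim = 1.0 - dist / max(len(extracted), len(ground_truth), 1)
    gt_words = ground_truth.split()
    extracted_words = extracted.split()
    word_acc = sum(w in extracted_words for w in gt_words) / max(len(gt_words), 1)
    return char_sim, word_acc
```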

---

## 4. Image Edit Score

### Pruna
- **Question**: `Rate 0-10: Does this image show "{prompt}"? Reply with a number.`
- **Input**: Single edited image + prompt
- **Output**: 0–10 score, normalized to [0, 1]
- **Source**: `metric_img_edit_score.py`

### InferBench
- **Input**: Original image + edited image + edit instruction
- **Judge prompts**: Fetched from the ImgEdit repo (`prompts.json`) per edit type (replace, add, remove, adjust, style, extract, background, compose)
- **Format**: Long multi-criteria prompts (5-point scale):
  - Prompt Compliance (1–5)
  - Visual Naturalness / Seamlessness (1–5)
  - Physical & Detail Integrity (1–5)
- **Output**: Average of the 3 scores, parsed from the `"Prompt Compliance: N\nVisual Naturalness: N\n..."` format
- **Source**: `img_edit_score.py`, `img_edit.py` (benchmark), external `prompts.json`
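
Parsing the `"Criterion: N"` judge output and averaging the three criteria might look like the sketch below (`parse_judge_scores` is a hypothetical name; the real parser in `img_edit_score.py` may be more lenient):

```python
import re


def parse_judge_scores(response: str) -> float:
    """Extract the three 1-5 criterion scores from a judge response and average them."""
    scores = [int(m) for m in re.findall(r":\s*([1-5])\b", response)]
    if len(scores) != 3:
        raise ValueError(f"expected 3 scores, got {len(scores)}: {response!r}")
    return sum(scores) / 3
```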

---

## 5. VieScore

### Pruna
- **Semantic**: `Rate 0-10: Does this image show "{prompt}"?`
- **Quality**: `Rate 0-10: How natural is this image? Any artifacts?`
- **Aggregation**: `sqrt(semantic * quality) / 10`
- **Source**: `metric_viescore.py`

### InferBench
- **SC (Semantic/Compliance)**: Long prompt with rules for editing success + overediting
  - Two images (original + edited)
  - `score1` = editing success (0–10), `score2` = overediting (0–10)
  - Output: `[score1, score2]`
- **PQ (Perceptual Quality)**: Long prompt for naturalness + artifacts
  - Single image
  - `naturalness` (0–10), `artifacts` (0–10)
  - Output: `[naturalness, artifacts]`
- **Aggregation**: `min(SC_scores)`, `min(PQ_scores)`, `overall = sqrt(SC * PQ)`
- **Context**: "You are a professional digital artist..." + JSON output format
- **Source**: `viescore.py`
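
The aggregation above (minimum over each score pair, then a geometric mean) is compact enough to write out as a sketch (function name assumed; divide by 10 for a [0, 1] overall, as in Pruna's variant):

```python
import math


def viescore_overall(sc_scores, pq_scores):
    """Take the worst of each score pair, then combine with a geometric mean."""
    sc = min(sc_scores)  # [editing success, overediting], each 0-10
    pq = min(pq_scores)  # [naturalness, artifacts], each 0-10
    return math.sqrt(sc * pq)
```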

---

## 6. QA Accuracy

### Pruna
- **Question**: `What is in this image? Answer:`
- **Scoring**: 1.0 if non-empty response, else 0.0
- **Use**: Generic image-understanding check
- **Source**: `metric_qa_accuracy.py`

### InferBench
- **Questions**: From GenEval metadata (e.g. "Does the image show at least one red apple?", "Does the image show exactly 3 cats?")
- **Template**: `{question} Please answer yes or no.`
- **Expected answers**: `Yes` for all (benchmark-specific)
- **Scoring**: Accuracy over N questions, plus `n_correct` and `n_incorrect` counts
- **Source**: `qa_accuracy.py`, `geneval.py` (benchmark)
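
The accuracy aggregation can be sketched as follows (a hypothetical helper; the actual `qa_accuracy.py` matching logic may differ):

```python
def qa_accuracy(responses, expected="yes"):
    """Return (accuracy, n_correct, n_incorrect) over a list of yes/no responses."""
    n_correct = sum(expected in r.lower() for r in responses)
    n_incorrect = len(responses) - n_correct
    accuracy = n_correct / max(len(responses), 1)
    return accuracy, n_correct, n_incorrect
```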

---

## 7. VLM Base Layer (Score Method)

### Pruna – LitellmVLM & TransformersVLM
- **Prompt**: `{question} Please answer yes or no.`
- **Scoring**: `1.0 if answer.lower() in response else 0.0` (the same substring check for both backends)
- **Source**: `vlm_base.py` line 371

### InferBench – OpenAIAPIVLM
- **Scoring**: Prefers logprobs (Yes/No token probabilities) when available
- **Fallback**: Generation + substring check ("yes"/"no" in response)
- **No prompt suffix**: Question passed as-is; metrics add their own suffix
- **Source**: `api_vlm_base.py`
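
The logprobs-preferred scoring with a generation fallback could be sketched as below (an illustration only; the real `api_vlm_base.py` token handling, e.g. casing or tokenizer variants of "Yes"/"No", is surely more involved):

```python
import math


def score_yes_no(logprobs=None, generation=None, answer="yes"):
    """Score the 'Yes' probability from token logprobs, else substring-check the generation.

    logprobs   : optional dict mapping candidate tokens to log-probabilities
    generation : optional generated text, used when logprobs are unavailable
    """
    if logprobs is not None:
        p_yes = math.exp(logprobs.get("Yes", float("-inf")))
        p_no = math.exp(logprobs.get("No", float("-inf")))
        total = p_yes + p_no
        p = p_yes / total if total > 0 else 0.0
        return p if answer == "yes" else 1.0 - p
    # Fallback: plain substring check on the generated text.
    return 1.0 if answer in (generation or "").lower() else 0.0
```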

---

## Recommendations

1. **Alignment / VQA**: InferBench's multi-question + dependency setup is more detailed; Pruna's single-question version is simpler. For OneIG-style benchmarks, InferBench's approach is required.

2. **Text Score**: InferBench's OCR prompt is more explicit and robust; Pruna now uses an InferBench-style OCR prompt and supports ground-truth edit distance when the ground truth contains `text_content`.

3. **Img Edit Score**: InferBench uses the full ImgEdit judge prompts; Pruna uses an improved single 0–10 rating with explicit scale instructions. For ImgEdit benchmarks, InferBench's prompts are necessary.

4. **VieScore**: InferBench's SC + PQ prompts match the original VieScore design; Pruna uses improved prompts with an explicit 0–10 scale.

5. **VLM Base**: Pruna now uses a unified "Please answer yes or no." suffix for both Litellm and Transformers.

pyproject.toml

Lines changed: 1 addition & 3 deletions
```diff
@@ -142,10 +142,8 @@ dependencies = [

 [project.optional-dependencies]
 evaluation = [
-    "pydantic>=2.0.0",
+    "outlines>1.2.0,<2.0.0",
     "litellm>=1.0.0",
-    "transformers>=4.40.0",
-    "accelerate>=0.20.0",
 ]

 stable-fast = [
```

src/pruna/evaluation/metrics/__init__.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -15,23 +15,22 @@
 from pruna.evaluation.metrics.registry import MetricRegistry # isort:skip

 from pruna.evaluation.metrics.aesthetic_laion import AestheticLAION
+from pruna.evaluation.metrics.metric_alignment_score import AlignmentScoreMetric
 from pruna.evaluation.metrics.metric_cmmd import CMMD
 from pruna.evaluation.metrics.metric_dino_score import DinoScore
 from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric, ThroughputMetric, TotalTimeMetric
 from pruna.evaluation.metrics.metric_energy import CO2EmissionsMetric, EnergyConsumedMetric
+from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric
 from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric
 from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric
 from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore
+from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric
 from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric
+from pruna.evaluation.metrics.metric_text_score import TextScoreMetric
 from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper
-from pruna.evaluation.metrics.metrics_vlm import (
-    AlignmentScoreMetric,
-    ImageEditScoreMetric,
-    QAAccuracyMetric,
-    TextScoreMetric,
-    VieScoreMetric,
-    VQAMetric,
-)
+from pruna.evaluation.metrics.metric_viescore import VieScoreMetric
+from pruna.evaluation.metrics.metric_vqa import VQAMetric
+from pruna.evaluation.metrics.vlm_base import BaseVLM, LitellmVLM, TransformersVLM, get_vlm

 __all__ = [
     "MetricRegistry",
@@ -57,4 +56,8 @@
     "QAAccuracyMetric",
     "TextScoreMetric",
     "VieScoreMetric",
+    "BaseVLM",
+    "LitellmVLM",
+    "TransformersVLM",
+    "get_vlm",
 ]
```
Lines changed: 120 additions & 0 deletions
```python
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Alignment Score metric using VLM for image-text alignment evaluation."""

from __future__ import annotations

from typing import Any, List, Literal, Optional

import numpy as np
import torch

from pruna.engine.utils import set_to_best_available_device
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.metric_vlm_utils import YesNoAnswer, _process_images
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor
from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm


@MetricRegistry.register("alignment_score")
class AlignmentScoreMetric(StatefulMetric):
    """
    Alignment Score metric using VLM.

    Assesses how well generated images match text prompts through structured questioning.
    Higher scores indicate better alignment.

    Parameters
    ----------
    *args : Any
        Additional positional arguments.
    vlm : BaseVLM | None, optional
        Custom VLM instance. If provided, vlm_type and model_name are ignored.
    vlm_type : {"litellm", "transformers"}, optional
        VLM backend. Default is "litellm".
    model_name : str, optional
        Model name. Default is "gpt-4o".
    vlm_kwargs : dict, optional
        Extra kwargs for VLM init (e.g. model_load_kwargs for transformers).
    structured_output : bool, optional
        Use structured generation. Default is True.
    use_outlines : bool, optional
        Use outlines for transformers. Default is False.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional arguments.
    """

    scores: List[float]
    default_call_type: str = "y"
    higher_is_better: bool = True
    metric_name: str = "alignment_score"
    runs_on: List[str] = ["cpu"]

    def __init__(
        self,
        *args,
        vlm: Optional[BaseVLM] = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str = "gpt-4o",
        vlm_kwargs: Optional[dict] = None,
        structured_output: bool = True,
        use_outlines: bool = False,
        device=None,
        api_key: Optional[str] = None,
        call_type: str = SINGLE,
        **kwargs,
    ):
        super().__init__(device=device)
        self.device = set_to_best_available_device(device)

        self.vlm = get_vlm(
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            device=device,
            api_key=api_key,
            use_outlines=use_outlines,
            **(vlm_kwargs or {}),
        )
        self.response_format = (
            YesNoAnswer if structured_output and vlm_type == "litellm" else
            ("yes_no" if structured_output and vlm_type == "transformers" else None)
        )

        self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
        self.add_state("scores", [])

    def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
        inputs = metric_data_processor(x, gt, outputs, self.call_type)
        images = _process_images(inputs[0])
        prompts = x if isinstance(x, list) else [""] * len(images)
        for i, image in enumerate(images):
            prompt = prompts[i] if i < len(prompts) else ""
            question = f'Does this image show "{prompt}"?'
            score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0]
            self.scores.append(score)

    def compute(self) -> MetricResult:
        if not self.scores:
            return MetricResult(self.metric_name, self.__dict__, 0.0)
        return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores)))
```
