Skip to content

Commit 209e12e

Browse files
feat(metrics): paper docstring fixes, VQA use_probability default, vlm docstrings
- VieScore: docstring arXiv:2312.14867, TIGER-AI-Lab/VIEScore. - Image Edit Score: docstring EditScore, ADIEE. - VQA: docstring arXiv:2404.01291, use_probability=True default. - vlm_base: full Parameters/Returns for score(), _score_with_logprobs. Made-with: Cursor
1 parent deab4b5 commit 209e12e

File tree

4 files changed

+116
-15
lines changed

4 files changed

+116
-15
lines changed

src/pruna/evaluation/metrics/metric_img_edit_score.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""
1616
Image Edit Score metric.
1717
18-
Reference: VieScore https://github.com/ByteDance/IEA-eval
18+
VLM-based instruction-following score for image editing. Evaluates how well an edited image
19+
follows the given editing instruction on a 0-10 scale. Related work: EditScore (arXiv:2509.23909),
20+
ADIEE (ICCV 2025).
1921
"""
2022

2123
from __future__ import annotations
@@ -40,8 +42,10 @@ class ImageEditScoreMetric(StatefulMetric):
4042
"""
4143
Image Edit Score metric.
4244
43-
Evaluates how well an image was edited based on editing instructions.
44-
Higher scores indicate better editing quality.
45+
VLM-based instruction-following score for image editing. Evaluates how well an edited image
46+
follows the given editing instruction. Higher scores indicate better editing quality.
47+
48+
Related work: EditScore (arXiv:2509.23909), ADIEE (ICCV 2025).
4549
4650
Parameters
4751
----------

src/pruna/evaluation/metrics/metric_viescore.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# limitations under the License.
1414

1515
"""
16-
VieScore metric for evaluating image quality (semantic + quality).
16+
VIEScore metric for evaluating conditional image synthesis (semantic + quality).
1717
18-
Reference: VieScore https://github.com/ByteDance/IEA-eval
18+
Reference: VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation
19+
(ACL 2024) - https://arxiv.org/abs/2312.14867, https://github.com/TIGER-AI-Lab/VIEScore
1920
"""
2021

2122
from __future__ import annotations
@@ -39,7 +40,7 @@
3940
@MetricRegistry.register("viescore")
4041
class VieScoreMetric(StatefulMetric):
4142
"""
42-
VieScore metric for evaluating image quality (semantic + quality).
43+
VIEScore metric for evaluating conditional image synthesis (semantic + quality).
4344
4445
Uses VLM to assess both semantic alignment and visual quality.
4546
Higher scores indicate better overall quality.
@@ -49,6 +50,12 @@ class VieScoreMetric(StatefulMetric):
4950
- Quality score: Naturalness and artifacts
5051
- Overall: Geometric mean of semantic and quality
5152
53+
References
54+
----------
55+
VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation (ACL 2024)
56+
https://arxiv.org/abs/2312.14867
57+
https://github.com/TIGER-AI-Lab/VIEScore
58+
5259
Parameters
5360
----------
5461
*args : Any

src/pruna/evaluation/metrics/metric_vqa.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
"""
1616
VQA (Visual Question Answering) metric.
1717
18-
Reference: VQAScore https://arxiv.org/abs/2310.08868
18+
Reference: VQAScore - Evaluating Text-to-Visual Generation with Image-to-Text Generation
19+
https://arxiv.org/abs/2404.01291
20+
21+
Note: VQAScore uses P(Yes) (probability of "Yes" answer) for ranking. With litellm,
22+
use_probability=True (default) requests logprobs for soft scores when the provider supports it.
23+
Set use_probability=False for binary 0/1. TransformersVLM always uses binary.
1924
"""
2025

2126
from __future__ import annotations
@@ -39,9 +44,12 @@ class VQAMetric(StatefulMetric):
3944
"""
4045
VQA (Visual Question Answering) metric.
4146
42-
Uses VLM to answer questions about images and compare with expected answers.
47+
Uses VLM to answer "Does this image show '{prompt}'?" and scores alignment.
4348
Higher scores indicate better image-text alignment.
4449
50+
VQAScore (arXiv:2404.01291) uses P(Yes) for ranking. Default use_probability=True
51+
with litellm requests logprobs for soft scores when supported.
52+
4553
Parameters
4654
----------
4755
*args : Any
@@ -64,6 +72,9 @@ class VQAMetric(StatefulMetric):
6472
API key for litellm.
6573
call_type : str, optional
6674
Call type for the metric.
75+
use_probability : bool, optional
76+
If True, use P(Yes) when backend supports logprobs (litellm). Otherwise binary 0/1.
77+
Default is True for paper alignment.
6778
**kwargs : Any
6879
Additional arguments.
6980
"""
@@ -86,11 +97,13 @@ def __init__(
8697
device=None,
8798
api_key: Optional[str] = None,
8899
call_type: str = SINGLE,
100+
use_probability: bool = True,
89101
**kwargs,
90102
):
91103
super().__init__(device=device)
92104
self.device = set_to_best_available_device(device)
93105
self.structured_output = structured_output
106+
self.use_probability = use_probability
94107

95108
self.vlm = get_vlm(
96109
vlm=vlm,
@@ -117,7 +130,13 @@ def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.T
117130
for i, image in enumerate(images):
118131
prompt = prompts[i] if i < len(prompts) else ""
119132
question = f'Does this image show "{prompt}"?'
120-
score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0]
133+
score = self.vlm.score(
134+
[image],
135+
[question],
136+
["Yes"],
137+
response_format=self.response_format,
138+
use_probability=self.use_probability,
139+
)[0]
121140
self.scores.append(score)
122141

123142
def compute(self) -> MetricResult:

src/pruna/evaluation/metrics/vlm_base.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import base64
3030
import io
31+
import math
3132
import os
3233
from abc import ABC, abstractmethod
3334
from typing import Any, List, Literal, Optional, Type, TypeVar
@@ -129,6 +130,7 @@ def score(
129130
images: List[Image.Image],
130131
questions: List[str],
131132
answers: List[str],
133+
use_probability: bool = False,
132134
**kwargs: Any,
133135
) -> List[float]:
134136
"""
@@ -142,13 +144,15 @@ def score(
142144
List of questions.
143145
answers : List[str]
144146
List of expected answers.
147+
use_probability : bool, optional
148+
If True and supported, return P(expected answer) instead of binary 0/1.
145149
**kwargs : Any
146150
Additional arguments passed to the implementation.
147151
148152
Returns
149153
-------
150154
List[float]
151-
Scores for each image-question pair.
155+
Scores for each image-question pair (0-1, or probability when use_probability).
152156
"""
153157
pass
154158

@@ -253,11 +257,15 @@ def score(
253257
images: List[Image.Image],
254258
questions: List[str],
255259
answers: List[str],
260+
use_probability: bool = False,
256261
**kwargs: Any,
257262
) -> List[float]:
258263
"""
259264
Score how well answers match images for given questions.
260265
266+
When use_probability=True, requests logprobs from the API and returns P(expected).
267+
Falls back to binary 0/1 if logprobs not available.
268+
261269
Parameters
262270
----------
263271
images : List[Image.Image]
@@ -266,22 +274,80 @@ def score(
266274
List of questions.
267275
answers : List[str]
268276
List of expected answers.
277+
use_probability : bool, optional
278+
If True, return P(expected) from logprobs when available. Default is False.
269279
**kwargs : Any
270-
Additional arguments passed to generate.
280+
Additional arguments passed to litellm completion.
271281
272282
Returns
273283
-------
274284
List[float]
275-
Scores for each image-question pair.
285+
Scores for each image-question pair (0-1, or probability when use_probability).
276286
"""
277287
scores = []
278288
for image, question, answer in zip(images, questions, answers):
279289
prompt = f"{question} Please answer yes or no."
280-
response = self.generate([image], [prompt], **kwargs)[0].lower()
281-
score = 1.0 if answer.lower() in response else 0.0
290+
if use_probability:
291+
score = self._score_with_logprobs(image, prompt, answer, **kwargs)
292+
else:
293+
response = self.generate([image], [prompt], **kwargs)[0].lower()
294+
score = 1.0 if answer.lower() in response else 0.0
282295
scores.append(score)
283296
return scores
284297

298+
def _score_with_logprobs(self, image: Image.Image, prompt: str, expected: str, **kwargs: Any) -> float:
299+
"""
300+
Get P(expected) from logprobs when available.
301+
302+
Parameters
303+
----------
304+
image : Image.Image
305+
PIL Image to score.
306+
prompt : str
307+
Question prompt.
308+
expected : str
309+
Expected answer (e.g., "Yes").
310+
**kwargs : Any
311+
Additional arguments passed to litellm completion.
312+
313+
Returns
314+
-------
315+
float
316+
Probability of expected answer (0-1), or binary 0/1 on fallback.
317+
"""
318+
content = [
319+
{"type": "text", "text": prompt},
320+
{"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}},
321+
]
322+
completion_kwargs = {
323+
"model": self.model_name,
324+
"messages": [{"role": "user", "content": content}],
325+
"api_key": self.api_key,
326+
"logprobs": True,
327+
"top_logprobs": 5,
328+
**self.extra_kwargs,
329+
**kwargs,
330+
}
331+
try:
332+
response = self._litellm.completion(**completion_kwargs)
333+
choice = response.choices[0]
334+
logprobs = getattr(choice, "logprobs", None) or getattr(choice.message, "logprobs", None)
335+
if logprobs and hasattr(logprobs, "content"):
336+
for tok in (logprobs.content or []):
337+
top = getattr(tok, "top_logprobs", None) or []
338+
for t in top:
339+
token_str = getattr(t, "token", "") or str(t).lower()
340+
if token_str and expected.lower() in token_str.lower():
341+
logprob = float(getattr(t, "logprob", -1e9) or -1e9)
342+
return min(1.0, max(0.0, math.exp(logprob)))
343+
content_str = (choice.message.content or "").lower()
344+
if expected.lower() in content_str:
345+
return 1.0
346+
return 0.0
347+
except Exception:
348+
response = self.generate([image], [prompt], **kwargs)[0].lower()
349+
return 1.0 if expected.lower() in response else 0.0
350+
285351
def _image_to_data_url(self, image: Image.Image) -> str:
286352
buffer = io.BytesIO()
287353
image.save(buffer, format="PNG")
@@ -458,11 +524,14 @@ def score(
458524
images: List[Image.Image],
459525
questions: List[str],
460526
answers: List[str],
527+
use_probability: bool = False,
461528
**kwargs: Any,
462529
) -> List[float]:
463530
"""
464531
Score how well answers match images for given questions.
465532
533+
use_probability is not supported for TransformersVLM; uses binary 0/1.
534+
466535
Parameters
467536
----------
468537
images : List[Image.Image]
@@ -471,13 +540,15 @@ def score(
471540
List of questions.
472541
answers : List[str]
473542
List of expected answers.
543+
use_probability : bool, optional
544+
Ignored; TransformersVLM always uses binary 0/1.
474545
**kwargs : Any
475546
Additional arguments passed to generate.
476547
477548
Returns
478549
-------
479550
List[float]
480-
Scores for each image-question pair.
551+
Scores for each image-question pair (0 or 1).
481552
"""
482553
scores = []
483554
for image, question, answer in zip(images, questions, answers):

0 commit comments

Comments (0)