Skip to content

Commit d3d659b

Browse files
fix(evaluation): enhance docstrings for VLM metrics and base classes
- Added detailed parameter descriptions to VQAnswer, ScoreOutput, and the various metric classes in metrics_vlm.py.
- Updated docstrings in the base classes of vlm_base.py to include parameter details and return types.
- Improved clarity and consistency across all metric-related docstrings.
1 parent 3dc944f commit d3d659b

File tree

3 files changed

+198
-19
lines changed

3 files changed

+198
-19
lines changed

src/pruna/evaluation/metrics/metrics_vlm.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,32 @@ def _process_images(images: torch.Tensor) -> List[Any]:
6060

6161
# Pydantic models for structured generation
6262
class VQAnswer(BaseModel):
63-
"""Structured output for VQA."""
63+
"""
64+
Structured output for VQA.
65+
66+
Parameters
67+
----------
68+
answer : str
69+
The VQA answer text.
70+
confidence : float, optional
71+
Confidence score. Default is 1.0.
72+
"""
6473

6574
answer: str
6675
confidence: float = 1.0
6776

6877

6978
class ScoreOutput(BaseModel):
70-
"""Structured output for scoring metrics."""
79+
"""
80+
Structured output for scoring metrics.
81+
82+
Parameters
83+
----------
84+
score : float
85+
The numeric score.
86+
reasoning : str | None, optional
87+
Optional reasoning for the score.
88+
"""
7189

7290
score: float
7391
reasoning: Optional[str] = None
@@ -89,6 +107,8 @@ class VQAMetric(StatefulMetric):
89107
90108
Parameters
91109
----------
110+
*args : Any
111+
Additional positional arguments.
92112
vlm_type : {"litellm", "transformers"}, optional
93113
VLM backend to use. Default is "litellm".
94114
model_name : str, optional
@@ -101,6 +121,8 @@ class VQAMetric(StatefulMetric):
101121
Device for transformers VLM.
102122
api_key : str | None, optional
103123
API key for litellm.
124+
call_type : str, optional
125+
Call type for the metric.
104126
**kwargs : Any
105127
Additional arguments.
106128
"""
@@ -190,10 +212,22 @@ class AlignmentScoreMetric(StatefulMetric):
190212
191213
Parameters
192214
----------
215+
*args : Any
216+
Additional positional arguments.
193217
vlm_type : {"litellm", "transformers"}, optional
194218
VLM backend. Default is "litellm".
219+
model_name : str, optional
220+
Model name. Default is "gpt-4o".
195221
structured_output : bool, optional
196222
Use structured generation. Default is True.
223+
use_outlines : bool, optional
224+
Use outlines for transformers. Default is False.
225+
device : str | torch.device | None, optional
226+
Device for transformers VLM.
227+
api_key : str | None, optional
228+
API key for litellm.
229+
call_type : str, optional
230+
Call type for the metric.
197231
**kwargs : Any
198232
Additional arguments.
199233
"""
@@ -277,6 +311,27 @@ class ImageEditScoreMetric(StatefulMetric):
277311
References
278312
----------
279313
VieScore: https://github.com/ByteDance/IEA-eval
314+
315+
Parameters
316+
----------
317+
*args : Any
318+
Additional positional arguments.
319+
vlm_type : {"litellm", "transformers"}, optional
320+
VLM backend. Default is "litellm".
321+
model_name : str, optional
322+
Model name. Default is "gpt-4o".
323+
structured_output : bool, optional
324+
Use structured generation. Default is True.
325+
use_outlines : bool, optional
326+
Use outlines for transformers. Default is False.
327+
device : str | torch.device | None, optional
328+
Device for transformers VLM.
329+
api_key : str | None, optional
330+
API key for litellm.
331+
call_type : str, optional
332+
Call type for the metric.
333+
**kwargs : Any
334+
Additional arguments.
280335
"""
281336

282337
scores: List[float]
@@ -361,6 +416,27 @@ class QAAccuracyMetric(StatefulMetric):
361416
362417
Uses VLM to answer questions about images.
363418
Higher scores indicate better image understanding.
419+
420+
Parameters
421+
----------
422+
*args : Any
423+
Additional positional arguments.
424+
vlm_type : {"litellm", "transformers"}, optional
425+
VLM backend. Default is "litellm".
426+
model_name : str, optional
427+
Model name. Default is "gpt-4o".
428+
structured_output : bool, optional
429+
Use structured generation. Default is True.
430+
use_outlines : bool, optional
431+
Use outlines for transformers. Default is False.
432+
device : str | torch.device | None, optional
433+
Device for transformers VLM.
434+
api_key : str | None, optional
435+
API key for litellm.
436+
call_type : str, optional
437+
Call type for the metric.
438+
**kwargs : Any
439+
Additional arguments.
364440
"""
365441

366442
scores: List[float]
@@ -437,6 +513,27 @@ class TextScoreMetric(StatefulMetric):
437513
438514
Uses VLM for OCR to extract text and compare with ground truth.
439515
Lower scores (edit distance) are better.
516+
517+
Parameters
518+
----------
519+
*args : Any
520+
Additional positional arguments.
521+
vlm_type : {"litellm", "transformers"}, optional
522+
VLM backend. Default is "litellm".
523+
model_name : str, optional
524+
Model name. Default is "gpt-4o".
525+
structured_output : bool, optional
526+
Use structured generation. Default is True.
527+
use_outlines : bool, optional
528+
Use outlines for transformers. Default is False.
529+
device : str | torch.device | None, optional
530+
Device for transformers VLM.
531+
api_key : str | None, optional
532+
API key for litellm.
533+
call_type : str, optional
534+
Call type for the metric.
535+
**kwargs : Any
536+
Additional arguments.
440537
"""
441538

442539
scores: List[float]
@@ -522,6 +619,27 @@ class VieScoreMetric(StatefulMetric):
522619
- Semantic score: How well image follows prompt
523620
- Quality score: Naturalness and artifacts
524621
- Overall: Geometric mean of semantic and quality
622+
623+
Parameters
624+
----------
625+
*args : Any
626+
Additional positional arguments.
627+
vlm_type : {"litellm", "transformers"}, optional
628+
VLM backend. Default is "litellm".
629+
model_name : str, optional
630+
Model name. Default is "gpt-4o".
631+
structured_output : bool, optional
632+
Use structured generation. Default is True.
633+
use_outlines : bool, optional
634+
Use outlines for transformers. Default is False.
635+
device : str | torch.device | None, optional
636+
Device for transformers VLM.
637+
api_key : str | None, optional
638+
API key for litellm.
639+
call_type : str, optional
640+
Call type for the metric.
641+
**kwargs : Any
642+
Additional arguments.
525643
"""
526644

527645
scores: List[float]

src/pruna/evaluation/metrics/vlm_base.py

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,25 @@ def generate(
5252
response_format: Optional[Type[BaseModel]] = None,
5353
**kwargs: Any,
5454
) -> List[str]:
55-
"""Generate responses for images and prompts."""
55+
"""
56+
Generate responses for images and prompts.
57+
58+
Parameters
59+
----------
60+
images : List[Image.Image]
61+
List of PIL Images.
62+
prompts : List[str]
63+
List of text prompts.
64+
response_format : Type[BaseModel] | None
65+
Optional pydantic model for structured output.
66+
**kwargs : Any
67+
Additional arguments passed to the implementation.
68+
69+
Returns
70+
-------
71+
List[str]
72+
Generated responses.
73+
"""
5674
pass
5775

5876
@abstractmethod
@@ -63,7 +81,25 @@ def score(
6381
answers: List[str],
6482
**kwargs: Any,
6583
) -> List[float]:
66-
"""Score how well answers match images for given questions."""
84+
"""
85+
Score how well answers match images for given questions.
86+
87+
Parameters
88+
----------
89+
images : List[Image.Image]
90+
List of PIL Images.
91+
questions : List[str]
92+
List of questions.
93+
answers : List[str]
94+
List of expected answers.
95+
**kwargs : Any
96+
Additional arguments passed to the implementation.
97+
98+
Returns
99+
-------
100+
List[float]
101+
Scores for each image-question pair.
102+
"""
67103
pass
68104

69105

@@ -73,6 +109,15 @@ class LitellmVLM(BaseVLM):
73109
74110
Supports 100+ LLM providers (OpenAI, Anthropic, Azure, etc.)
75111
Default model is gpt-4o.
112+
113+
Parameters
114+
----------
115+
model_name : str, optional
116+
Model name (e.g., gpt-4o). Default is "gpt-4o".
117+
api_key : str | None, optional
118+
API key for the provider. Uses LITELLM_API_KEY or OPENAI_API_KEY env if None.
119+
**kwargs : Any
120+
Additional arguments passed to litellm.
76121
"""
77122

78123
def __init__(
@@ -111,6 +156,8 @@ def generate(
111156
List of text prompts.
112157
response_format : Type[BaseModel] | None
113158
Optional pydantic model for structured output.
159+
**kwargs : Any
160+
Additional arguments passed to litellm completion.
114161
115162
Returns
116163
-------
@@ -169,6 +216,8 @@ def score(
169216
List of questions.
170217
answers : List[str]
171218
List of expected answers.
219+
**kwargs : Any
220+
Additional arguments passed to generate.
172221
173222
Returns
174223
-------
@@ -196,6 +245,17 @@ class TransformersVLM(BaseVLM):
196245
VLM using HuggingFace Transformers for local inference.
197246
198247
Supports models like BLIP, LLaVA, etc.
248+
249+
Parameters
250+
----------
251+
model_name : str, optional
252+
HuggingFace model name. Default is "Salesforce/blip2-opt-2.7b".
253+
device : str | torch.device | None, optional
254+
Device for inference. Auto-detected if None.
255+
use_outlines : bool, optional
256+
Use outlines for constrained decoding. Default is False.
257+
**kwargs : Any
258+
Additional arguments passed to model generation.
199259
"""
200260

201261
def __init__(
@@ -244,20 +304,22 @@ def generate(
244304
"""
245305
Generate responses using local VLM.
246306
247-
Args:
248-
images: List of PIL Images
249-
prompts: List of text prompts
250-
response_format: Optional format constraint (e.g., "json", "integer")
251-
"""
252-
"""
307+
Parameters
308+
----------
309+
images : List[Image.Image]
310+
List of PIL Images.
311+
prompts : List[str]
312+
List of text prompts.
313+
response_format : str | None
314+
Optional format constraint (e.g., "json", "integer", "yes_no").
315+
**kwargs : Any
316+
Additional arguments passed to model generate.
253317
254-
Generate responses using local VLM.
255-
Args:
256-
images: List of PIL Images
257-
prompts: List of text prompts
258-
response_format: Optional format constraint (e.g., "json", "integer")
318+
Returns
319+
-------
320+
List[str]
321+
Generated responses.
259322
"""
260-
261323
self._load_model()
262324
results = []
263325
max_new_tokens = kwargs.get("max_new_tokens", 128)
@@ -347,6 +409,8 @@ def score(
347409
List of questions.
348410
answers : List[str]
349411
List of expected answers.
412+
**kwargs : Any
413+
Additional arguments passed to generate.
350414
351415
Returns
352416
-------

tests/style/test_docstrings.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,4 @@ def test_docstrings(file: str) -> None:
1414
file : str
1515
The import statement to check.
1616
"""
17-
# Skip metrics_vlm module as it uses a different docstring pattern for VLM parameters
18-
if "metrics_vlm" in file:
19-
pytest.skip("metrics_vlm uses custom VLM parameter documentation")
2017
check_docstrings_content(file)

0 commit comments

Comments
 (0)