
Commit 76d573d

Kyunnilee, tsunghan-wu, and Luodian authored
easier code for multiple images (#879)
* add coco captioning chair
* add chair recall
* bootstrapping
* add amber_g
* amber-works
* Add an easy flag to control image ordering (incomplete code)
* we need to do interleaved at the end of the day...
* fix typo
* update evaluator to prevent chair customized
* mmbench two
* clean code
* file upload bug fix
* enable double img mmmu
* hallusion_bench
* bootstrap for amber
* fix: resolve conflicts and fix code review issues
  - Fix args.output -> args.output_path in file_utils.py
  - Replace hardcoded user path with reasonable default in amber_g
  - Remove commented-out debug print statements
  - Remove dead code blocks and informal comments
  Fixes issues identified in code review.

---------

Co-authored-by: Patrick Wu <[email protected]>
Co-authored-by: Brian Li <[email protected]>
1 parent 5e46cd9 commit 76d573d

File tree

12 files changed: +792 -25 lines

lmms_eval/api/metrics.py

Lines changed: 40 additions & 1 deletion
@@ -531,11 +531,37 @@ def bootstrap_stderr(f, xs, iters):
     return sample_stddev(res)
 
 
+def bootstrap_chair_metric(metric_fn, xs, iters):
+    "for non multiprocessing for CHAIR"
+    print(f"bootstrapping for stddev: {metric_fn.__name__}")
+    res = []
+    from tqdm import tqdm
+
+    for _ in tqdm(range(iters), desc="Bootstrap"):
+        bootstrap_sample = random.choices(xs, k=len(xs))
+        metric_value = metric_fn(bootstrap_sample)
+        res.append(metric_value)
+
+    return sample_stddev(res)
+
 def stderr_for_metric(metric, bootstrap_iters: int):
     if bootstrap_iters <= 0:
         # return no function (don't compute stderr) if bootstrap iters = 0
         return None
-
+    # for coco_cap_chair
+    from lmms_eval.tasks.coco_cap_chair.utils import (
+        coco_cap_chair_aggregate_results_chair_i,
+        coco_cap_chair_aggregate_results_chair_s,
+        coco_cap_chair_aggregate_results_recall,
+    )
+    # for amber_g
+    from lmms_eval.tasks.amber_g.utils import (
+        amber_g_aggregate_chair,
+        amber_g_aggregate_cover,
+        amber_g_aggregate_hal,
+        amber_g_aggregate_cog,
+    )
+
     bootstrappable = [
         median,
         matthews_corrcoef,
@@ -544,11 +570,24 @@ def stderr_for_metric(metric, bootstrap_iters: int):
         bleu,
         chrf,
         ter,
+        coco_cap_chair_aggregate_results_chair_i,
+        coco_cap_chair_aggregate_results_chair_s,
+        coco_cap_chair_aggregate_results_recall,
+        amber_g_aggregate_chair,
+        amber_g_aggregate_cover,
+        amber_g_aggregate_hal,
+        amber_g_aggregate_cog,
     ]
 
     if metric in bootstrappable:
         return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
 
+    if hasattr(metric, '__name__'):
+        if 'coco_cap_chair' in metric.__name__:
+            return lambda x: bootstrap_chair_metric(metric, x, iters=bootstrap_iters)
+        if 'amber_g' in metric.__name__ or 'amber_' in metric.__name__:
+            return lambda x: bootstrap_chair_metric(metric, x, iters=bootstrap_iters)
+
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
 
     return stderr.get(metric, None)
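How the new dispatch behaves, as a minimal self-contained sketch: the toy aggregate, its "hallucinated" key, and the sample data below are illustrative assumptions, not part of the commit; only the resampling loop mirrors bootstrap_chair_metric above.

import random

def sample_stddev(xs):
    # sample standard deviation of the bootstrap replicates
    # (stands in for the sample_stddev already defined in metrics.py)
    mean = sum(xs) / len(xs)
    return (sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)) ** 0.5

def bootstrap_chair_metric(metric_fn, xs, iters):
    # resample the per-sample results with replacement, re-aggregate,
    # and report the spread of the re-aggregated metric values
    res = []
    for _ in range(iters):
        res.append(metric_fn(random.choices(xs, k=len(xs))))
    return sample_stddev(res)

# toy stand-in for an aggregate such as coco_cap_chair_aggregate_results_chair_i;
# the "hallucinated" key is invented for this sketch
def toy_coco_cap_chair_aggregate(samples):
    return sum(s["hallucinated"] for s in samples) / len(samples)

samples = [{"hallucinated": i % 3 == 0} for i in range(100)]
print(toy_coco_cap_chair_aggregate(samples))  # point estimate
print(bootstrap_chair_metric(toy_coco_cap_chair_aggregate, samples, iters=200))  # its stderr

Because the fallback matches on metric.__name__ rather than identity, any aggregation whose name contains coco_cap_chair or amber_ is routed to the single-process bootstrap, which suits CHAIR-style aggregates that must see the whole result list at once.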

lmms_eval/evaluator.py

Lines changed: 1 addition & 0 deletions
@@ -576,6 +576,7 @@ def evaluate(
                 "doc_id": doc_id,
                 "doc": saved_doc,
                 "target": target,
+                # "pred": metrics['coco_cap_chair_i']['pred'],
                 "arguments": filtered_arguments,
                 "resps": [req.resps for req in requests],
                 "filtered_resps": [req.filtered_resps[filter_key] for req in requests],

lmms_eval/models/simple/llava_onevision1_5.py

Lines changed: 16 additions & 7 deletions
@@ -46,6 +46,7 @@ def __init__(
         max_image_size: Optional[int] = None,  # Only applicable if use_custom_video_loader is True
         system_prompt: Optional[str] = "You are a helpful assistant.",
         interleave_visuals: Optional[bool] = False,
+        image_first: Optional[bool] = True,
         reasoning_prompt: Optional[str] = None,
         max_length: int = 2048,
         **kwargs,
@@ -86,7 +87,7 @@ def __init__(
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
         self.max_num_frames = max_num_frames
-
+        self.image_first = image_first
         if reasoning_prompt:
             self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
         else:
@@ -236,12 +237,20 @@ def _collate(x):
                 processed_visuals.append({"type": "image", "image": visual.convert("RGB")})
 
             if self.interleave_visuals is False:
-                message.append(
-                    {
-                        "role": "user",
-                        "content": processed_visuals + [{"type": "text", "text": context}],
-                    }
-                )
+                if self.image_first:
+                    message.append(
+                        {
+                            "role": "user",
+                            "content": processed_visuals + [{"type": "text", "text": context}],
+                        }
+                    )
+                else:
+                    message.append(
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": context}] + processed_visuals,
+                        }
+                    )
             else:  # currently support find <image x> in the context
                 image_placeholders = re.findall(r"<image \d+>", context)
                 content_parts = []
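What the flag changes, in a minimal sketch assuming two already-processed visuals; the PIL image objects are replaced by placeholder strings and the model call is elided.

# illustrative only: the two non-interleaved message layouts
processed_visuals = [
    {"type": "image", "image": "<PIL image 1>"},
    {"type": "image", "image": "<PIL image 2>"},
]
context = "Which of the two images shows a dog?"
image_first = True  # constructor default for this model

if image_first:
    content = processed_visuals + [{"type": "text", "text": context}]
else:
    content = [{"type": "text", "text": context}] + processed_visuals

message = [{"role": "user", "content": content}]

When interleave_visuals is True the flag is bypassed entirely and <image x> placeholders found in the context decide where each image lands.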

lmms_eval/models/simple/vllm.py

Lines changed: 10 additions & 5 deletions
@@ -156,6 +156,7 @@ def __init__(
         chat_template: Optional[str] = None,
         min_image_pixels: int = 28,  # minimum image dimension, required for Qwen 2/2.5-VL models
         disable_log_stats: bool = False,
+        image_first: bool = False,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -167,6 +168,7 @@ def __init__(
         self.chat_template = chat_template
         self.min_image_pixels = min_image_pixels
         self.data_parallel_size = data_parallel_size
+        self.image_first = image_first
         # Qwen 2/2.5-VL models enforce minimum image dimensions
         self._enforce_image_resize = self._is_qwen_vl_model(model)
 
@@ -338,11 +340,14 @@ def generate_until(self, requests) -> List[str]:
                 imgs.append(task.result())
 
             messages = [{"role": "user", "content": []}]
-            # Add images first, then text
-            for img in self.flatten(imgs):
-                messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
-            messages[0]["content"].append({"type": "text", "text": contexts})
-
+            if self.image_first:
+                for img in self.flatten(imgs):
+                    messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
+                messages[0]["content"].append({"type": "text", "text": contexts})
+            else:
+                messages[0]["content"].append({"type": "text", "text": contexts})
+                for img in self.flatten(imgs):
+                    messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
             batched_messages.append(messages)
 
         sampling_params = SamplingParams(**params)
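The same toggle in the OpenAI-style chat format vllm consumes; the base64 payloads below are truncated placeholders and build_message is a hypothetical helper written for this sketch, not a function in the file.

def build_message(contexts, imgs, image_first):
    # imgs: flattened list of base64-encoded PNGs (truncated placeholders below)
    image_parts = [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}} for img in imgs]
    text_part = {"type": "text", "text": contexts}
    content = image_parts + [text_part] if image_first else [text_part] + image_parts
    return [{"role": "user", "content": content}]

print(build_message("Describe both images.", ["iVBORw0...", "iVBORw0..."], image_first=False))

Note that the two backends choose opposite defaults: llava_onevision1_5 defaults to image_first=True, preserving its previous images-then-text layout, while vllm defaults to False, which inverts the ordering its old hardcoded "Add images first, then text" block produced.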

lmms_eval/tasks/_task_utils/file_utils.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 
 def generate_submission_file(file_name, args, subpath="submissions"):
-    if args.output_path is None:
+    if args is None or args.output_path is None:
         # If no output path is specified, use current directory
         path = subpath
     else:
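A self-contained sketch of the guarded behavior; the branch taken when args is present, the makedirs call, and the return value are assumptions based on the visible context, not the full original body.

import os

def generate_submission_file(file_name, args, subpath="submissions"):
    # args may now be None, e.g. when a task utility runs outside the CLI harness
    if args is None or args.output_path is None:
        # If no output path is specified, use current directory
        path = subpath
    else:
        path = os.path.join(args.output_path, subpath)  # assumed join, per repo convention
    os.makedirs(path, exist_ok=True)  # assumed; mirrors typical usage
    return os.path.join(path, file_name)

# before the guard, passing args=None raised AttributeError on args.output_path
print(generate_submission_file("amber_g_results.json", None))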
lmms_eval/tasks/amber_g/amber_g.yaml

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# AMBER-G (Generative Task) Evaluation Configuration
+# Based on: https://github.com/junyangwang0410/AMBER
+# Dataset includes: images, questions, and complete ground truth annotations
+
+dataset_path: Kyunnilee/amber_g # use this dataset
+dataset_kwargs:
+  trust_remote_code: true
+task: "amber_g"
+output_type: generate_until
+
+doc_to_visual: !function utils.amber_g_doc_to_visual
+doc_to_text: !function utils.amber_g_doc_to_text
+doc_to_target: "truth"
+test_split: train
+
+generation_kwargs:
+  max_new_tokens: 2048
+  temperature: 0
+  top_p: 1.0
+  num_beams: 1
+  do_sample: false
+  until: [] # really important!!! the default would be ["\n\n"] and that will cause truncation
+
+process_results: !function utils.amber_g_process_result
+
+# AMBER-G Metrics:
+metric_list:
+  - metric: amber_chair
+    aggregation: !function utils.amber_g_aggregate_chair
+    higher_is_better: false
+  - metric: amber_cover
+    aggregation: !function utils.amber_g_aggregate_cover
+    higher_is_better: true
+  - metric: amber_hal
+    aggregation: !function utils.amber_g_aggregate_hal
+    higher_is_better: false
+  - metric: amber_cog
+    aggregation: !function utils.amber_g_aggregate_cog
+    higher_is_better: true
+
+metadata:
+  - version: 0.0
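For orientation, a hedged sketch of one aggregation hook named above; only the function name and its list-in, float-out contract are implied by this config and the metrics.py changes, while the per-sample keys (hallucinated_objects, mentioned_objects) are invented for illustration.

def amber_g_aggregate_chair(results):
    # results: one dict per evaluated sample, as emitted by amber_g_process_result
    # CHAIR = hallucinated object mentions / all object mentions (lower is better)
    hallucinated = sum(r["hallucinated_objects"] for r in results)
    mentioned = sum(r["mentioned_objects"] for r in results)
    return hallucinated / mentioned if mentioned else 0.0

Because these aggregations consume the whole result list rather than averaging per-sample scores, a closed-form stderr is unavailable, which is why stderr_for_metric routes them to bootstrap_chair_metric.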
