Skip to content

Commit c2ea941

Browse files
author
Nihal John George
committed
vision-debug: add basic correctness check using existing datasets
1 parent d23d6f5 commit c2ea941

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import json
2+
import random
3+
import sys
4+
import base64
5+
6+
from io import BytesIO
7+
from pathlib import Path
8+
from typing import Any, Dict, List, Optional, Tuple, Union
9+
10+
import numpy as np
11+
import tvm
12+
from datasets import load_dataset
13+
from tvm import relax
14+
from tvm.contrib import tvmjs
15+
from tvm.runtime import Device, Module, Object, ShapeTuple
16+
from tvm.runtime.relax_vm import VirtualMachine
17+
18+
from mlc_llm import MLCEngine
19+
from mlc_llm.conversation_template import ConvTemplateRegistry
20+
from mlc_llm.interface.help import HELP
21+
from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
22+
from mlc_llm.serve import data, engine_utils
23+
from mlc_llm.support.argparse import ArgumentParser
24+
from mlc_llm.support.auto_device import detect_device
25+
from mlc_llm.support.style import green, red
26+
from mlc_llm.tokenizers import Tokenizer
27+
28+
# Few-shot prompt prefix for Phi-3.5-vision on MMMU-style multiple-choice
# questions: two worked examples, each with a question, lettered options,
# and the gold answer letter. This is a runtime string sent to the model,
# so its content is preserved exactly.
prompt_phi_3_5_v_few_shot = """Question:
Which of the following is the body cavity that contains the pituitary gland?
Options:
A. Abdominal
B. Cranial
C. Pleural
D. Spinal
Answer: B
Question:
Where was the most famous site of the mystery cults in Greece?
Options:
A. Ephesus
B. Corinth
C. Athens
D. Eleusis
Answer: D

"""

# Zero-shot variant: an empty prefix — the question is sent with no examples.
prompt_phi_3_5_v_zero_shot = """"""
48+
49+
def encode_image(image):
    """Serialize *image* to PNG in memory and return it base64-encoded.

    *image* is any object exposing a PIL-style ``save(fileobj, format=...)``
    method (e.g. a ``PIL.Image.Image``). The PNG bytes are captured in a
    ``BytesIO`` buffer and returned as a UTF-8 base64 string.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    # getvalue() returns the full buffer contents without needing a seek(0).
    return base64.b64encode(png_buffer.getvalue()).decode('utf-8')
56+
57+
def construct_prompt_mmmu(ex, prompt_prefix=prompt_phi_3_5_v_zero_shot):
    """Build the text prompt for one MMMU example.

    Appends the example's question and lettered options ("A: ...", "B: ...")
    to *prompt_prefix* and terminates with "Answer: " so the model is nudged
    to complete with a single answer letter.

    Args:
        ex: one dataset row (mapping); may contain 'question' (str) and
            'options' (a string holding a Python list literal, as MMMU
            stores it — presumably; verify against the dataset schema).
        prompt_prefix: text placed before the question (few-shot examples
            or the empty zero-shot prefix).

    Returns:
        The assembled prompt string.

    Raises:
        ValueError/SyntaxError: if 'options' is not a valid Python literal.
    """
    overall_prompt = prompt_prefix
    if 'question' in ex:
        overall_prompt += ex['question'] + "\n"
    if 'options' in ex:
        # The dataset stores options as a stringified list. literal_eval
        # parses it safely; the original `eval` would execute arbitrary
        # code embedded in dataset content.
        options = ast.literal_eval(ex['options'])
        for oi, option in enumerate(options):
            # chr(65) == 'A', so options are labelled A, B, C, ...
            overall_prompt += f"{chr(oi+65)}: {option}\n"
    overall_prompt += "Answer: "
    return overall_prompt
67+
68+
def eval_mmmu(model, engine: MLCEngine, prompt=prompt_phi_3_5_v_zero_shot, temperature=0.0):
    """Evaluate *engine* on the MMMU validation split and print accuracy.

    For every example in each configured slice, the example image and the
    constructed text prompt are sent as one chat completion; the first
    non-space character of the reply is compared against the gold answer
    letter. Per-slice and overall statistics are printed to stdout.

    Args:
        model: model identifier forwarded to the chat-completions API.
        engine: an initialized MLCEngine.
        prompt: prompt prefix prepended to every question.
        temperature: sampling temperature (0.0 for greedy decoding).
    """
    slices = ["Accounting"]
    slice_correct = []
    slice_total = []
    for sl in slices:
        ds = load_dataset("MMMU/MMMU", sl)
        slice_correct_here = 0
        slice_total_here = 0
        for exi in range(len(ds['validation'])):
            ex = ds['validation'][exi]
            # Honor the caller-supplied prefix. Previously the zero-shot
            # constant was hard-coded here, silently ignoring `prompt`.
            preproc_ex = construct_prompt_mmmu(ex, prompt)
            base64_image = encode_image(ex["image_1"])
            response = engine.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                # encode_image produces PNG bytes, so the data
                                # URI must say image/png (was image/jpeg,
                                # mismatching the actual payload).
                                "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                            },
                            {
                                "type": "text",
                                "text": preproc_ex,
                            },
                        ],
                    }
                ],
                model=model,
                stream=False,
                temperature=temperature,
            )
            ans = response.choices[0].message.content
            # Only the first character of the (stripped) reply is compared
            # with the gold letter, so verbose replies like "B. Cranial"
            # still score correctly.
            if ans.strip()[:1] == ds['validation'][exi]['answer'].strip():
                slice_correct_here += 1
                print("Correct")
            else:
                print("Wrong")
            slice_total_here += 1

        slice_correct.append(slice_correct_here)
        slice_total.append(slice_total_here)
        # Guard the division: an empty slice would otherwise crash here.
        acc = slice_correct_here / slice_total_here if slice_total_here else float('nan')
        print(f"Slice: {sl} ; Statistics Below\nCorrect: {slice_correct_here}\nTotal: {slice_total_here}\nAccuracy: {acc}")

    overall_total = sum(slice_total)
    correct_total = sum(slice_correct)
    overall_acc = correct_total / overall_total if overall_total else float('nan')
    print(f"Overall Statistics Below\nCorrect: {correct_total}\nTotal: {overall_total}\nAccuracy: {overall_acc}")
115+
116+
def main():
    """Entry point for the MLC LLM vision correctness benchmark CLI.

    Parses the model/model-lib/device/temperature arguments, constructs an
    MLCEngine, and runs the MMMU evaluation.
    """
    # (Docstring previously said "DebugChat CLI" — copied from another tool.)
    parser = ArgumentParser("MLC LLM Correctness Benchmark")
    parser.add_argument(
        "--model",
        type=str,
        help="An MLC model directory that contains `mlc-chat-config.json`",
        required=True,
    )
    parser.add_argument(
        "--model-lib",
        type=str,
        help="The full path to the model library file to use (e.g. a ``.so`` file).",
        required=True,
    )
    # NOTE(review): --device is accepted but never forwarded to MLCEngine
    # below; kept for CLI compatibility — confirm whether it should be
    # plumbed through (detect_device is imported but unused as well).
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help=HELP["device_compile"] + ' (default: "%(default)s")',
    )
    parser.add_argument(
        "--temperature",
        # Without type=float, a CLI-supplied value (e.g. "--temperature 0.5")
        # reached the chat API as a str while the default stayed a float.
        type=float,
        default=0.0,
        help="temperature for generation",
    )
    parsed = parser.parse_args()
    engine = MLCEngine(parsed.model, model_lib=parsed.model_lib)
    eval_mmmu(parsed.model, engine, temperature=parsed.temperature)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)