Skip to content

Commit aa80886

Browse files
update eval_answer for multi turn (#424)
1 parent 5f2895f commit aa80886

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

llmc/eval/eval_custom_generate_just_infer.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
1-
import glob
21
import json
32
import os
43

54
import torch
6-
from human_eval.data import stream_jsonl, write_jsonl
7-
from human_eval.evaluation import evaluate_functional_correctness
85
from loguru import logger
9-
from tqdm import tqdm
10-
11-
from .eval_base import BaseEval
126

137

148
class CustomGenerateJustInfer:
@@ -29,8 +23,45 @@ def eval(self, model, eval_pos=None):
2923
self.eval_cfg
3024
)
3125

32-
with open(os.path.join('custom_samples_ans.json'), 'w') as f:
26+
self.eval_answer(custom_samples_ans)
27+
28+
with open(os.path.join(self.config.save.save_path), 'w') as f:
3329
json.dump(custom_samples_ans, f, indent=4)
3430

3531
torch.cuda.empty_cache()
3632
return 'custom gen done.'
33+
34+
def eval_answer(self, data):
35+
T1V = 0
36+
T1V_T2V = 0
37+
38+
def create_pairs(lst):
39+
return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
40+
41+
def check_acc(gt, answer, turn):
42+
if gt[turn].lower() in answer[turn].lower():
43+
return True
44+
return False
45+
46+
pair_data = create_pairs(data)
47+
48+
for idx, item in enumerate(pair_data):
49+
assert item[0]['image'] == item[1]['image']
50+
51+
pair1 = item[0]
52+
pair2 = item[1]
53+
54+
if check_acc(pair1['gt'], pair1['answer'], 0):
55+
T1V += 1
56+
if check_acc(pair2['gt'], pair2['answer'], 1):
57+
T1V_T2V += 1
58+
assert pair1['question'][0] == pair2['question'][1]
59+
60+
if check_acc(pair2['gt'], pair2['answer'], 0):
61+
T1V += 1
62+
if check_acc(pair1['gt'], pair1['answer'], 1):
63+
T1V_T2V += 1
64+
assert pair2['question'][0] == pair1['question'][1]
65+
66+
logger.info(f'CustomGenerateJustInfer T1V: {T1V}, T1V_T2V: {T1V_T2V}')
67+
logger.info(f'CustomGenerateJustInfer Possibility: {T1V_T2V / T1V}')

0 commit comments

Comments
 (0)