1- import glob
21import json
32import os
43
54import torch
6- from human_eval .data import stream_jsonl , write_jsonl
7- from human_eval .evaluation import evaluate_functional_correctness
85from loguru import logger
9- from tqdm import tqdm
10-
11- from .eval_base import BaseEval
126
137
148class CustomGenerateJustInfer :
@@ -29,8 +23,45 @@ def eval(self, model, eval_pos=None):
2923 self .eval_cfg
3024 )
3125
32- with open (os .path .join ('custom_samples_ans.json' ), 'w' ) as f :
26+ self .eval_answer (custom_samples_ans )
27+
28+ with open (os .path .join (self .config .save .save_path ), 'w' ) as f :
3329 json .dump (custom_samples_ans , f , indent = 4 )
3430
3531 torch .cuda .empty_cache ()
3632 return 'custom gen done.'
33+
34+ def eval_answer (self , data ):
35+ T1V = 0
36+ T1V_T2V = 0
37+
38+ def create_pairs (lst ):
39+ return [(lst [i ], lst [i + 1 ]) for i in range (0 , len (lst ), 2 )]
40+
41+ def check_acc (gt , answer , turn ):
42+ if gt [turn ].lower () in answer [turn ].lower ():
43+ return True
44+ return False
45+
46+ pair_data = create_pairs (data )
47+
48+ for idx , item in enumerate (pair_data ):
49+ assert item [0 ]['image' ] == item [1 ]['image' ]
50+
51+ pair1 = item [0 ]
52+ pair2 = item [1 ]
53+
54+ if check_acc (pair1 ['gt' ], pair1 ['answer' ], 0 ):
55+ T1V += 1
56+ if check_acc (pair2 ['gt' ], pair2 ['answer' ], 1 ):
57+ T1V_T2V += 1
58+ assert pair1 ['question' ][0 ] == pair2 ['question' ][1 ]
59+
60+ if check_acc (pair2 ['gt' ], pair2 ['answer' ], 0 ):
61+ T1V += 1
62+ if check_acc (pair1 ['gt' ], pair1 ['answer' ], 1 ):
63+ T1V_T2V += 1
64+ assert pair2 ['question' ][0 ] == pair1 ['question' ][1 ]
65+
66+ logger .info (f'CustomGenerateJustInfer T1V: { T1V } , T1V_T2V: { T1V_T2V } ' )
67+ logger .info (f'CustomGenerateJustInfer Possibility: { T1V_T2V / T1V } ' )
0 commit comments