Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions llmc/eval/eval_custom_generate_just_infer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import glob
import json
import os

import torch
from human_eval.data import stream_jsonl, write_jsonl
from human_eval.evaluation import evaluate_functional_correctness
from loguru import logger
from tqdm import tqdm

from .eval_base import BaseEval


class CustomGenerateJustInfer:
Expand All @@ -29,8 +23,45 @@ def eval(self, model, eval_pos=None):
self.eval_cfg
)

with open(os.path.join('custom_samples_ans.json'), 'w') as f:
self.eval_answer(custom_samples_ans)

with open(os.path.join(self.config.save.save_path), 'w') as f:
json.dump(custom_samples_ans, f, indent=4)

torch.cuda.empty_cache()
return 'custom gen done.'

def eval_answer(self, data):
T1V = 0
T1V_T2V = 0
Comment on lines +35 to +36

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable names T1V and T1V_T2V are not descriptive and reduce code readability. Consider using more descriptive names that clearly state their purpose, such as turn1_valid_count and both_turns_valid_count.


def create_pairs(lst):
return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list comprehension will raise an IndexError if lst has an odd number of elements. Validate the input to prevent unexpected crashes. Consider adding a check at the beginning of the function.

Suggested change
return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
assert len(lst) % 2 == 0, "Input data for pairing must have an even number of elements."


def check_acc(gt, answer, turn):
if gt[turn].lower() in answer[turn].lower():
return True
return False
Comment on lines +42 to +44

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if-else block for returning a boolean value can be simplified to a single return statement for conciseness.

            return gt[turn].lower() in answer[turn].lower()


pair_data = create_pairs(data)

for idx, item in enumerate(pair_data):
assert item[0]['image'] == item[1]['image']

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert for data validation can be risky, as assertions can be disabled with the -O flag in Python, which is common in production environments. For validating input data, it's generally safer to raise a ValueError or log an error and continue, to ensure the check is always performed.


pair1 = item[0]
pair2 = item[1]

if check_acc(pair1['gt'], pair1['answer'], 0):
T1V += 1
if check_acc(pair2['gt'], pair2['answer'], 1):
T1V_T2V += 1
assert pair1['question'][0] == pair2['question'][1]

if check_acc(pair2['gt'], pair2['answer'], 0):
T1V += 1
if check_acc(pair1['gt'], pair1['answer'], 1):
T1V_T2V += 1
assert pair2['question'][0] == pair1['question'][1]

logger.info(f'CustomGenerateJustInfer T1V: {T1V}, T1V_T2V: {T1V_T2V}')
logger.info(f'CustomGenerateJustInfer Possibility: {T1V_T2V / T1V}')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This division will raise a ZeroDivisionError if T1V is 0. Handle this case to prevent the script from crashing.

        logger.info(f'CustomGenerateJustInfer Possibility: {T1V_T2V / T1V if T1V > 0 else 0.0}')