Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions dataflow/core/process/reasoner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ def __init__(self, args=None):
self.filter_name = "ReasonerFilter"
self.args = args

api_args = args['api_args']
self.model_name = api_args['model_name']
self.api_url = api_args['api_url']
self.mode_test = api_args['mode_test']
api_args = args.get('api_args', None)
if api_args is not None:
self.model_name = api_args['model_name']
self.api_url = api_args['api_url']
self.mode_test = api_args['mode_test']
def filter_func(self, dataset):
pass

Expand Down
14 changes: 2 additions & 12 deletions dataflow/process/text/reasoners/answer_formatter_filter.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
from dataflow.core import TextFilter
from dataflow.core import TextFilter, ReasonerFilter
import numpy as np
from dataflow.utils.registry import PROCESSOR_REGISTRY
import re

@PROCESSOR_REGISTRY.register()
class AnswerFormatterFilter(TextFilter):
class AnswerFormatterFilter(ReasonerFilter):
def __init__(self, args_dict: dict):
super().__init__(args_dict)
self.filter_name = 'AnswerFormatterFilter'

def is_valid_answer(answer: str) -> bool:
# start with "Solution:"
if not answer.startswith("Solution:"):
return False

# check that every step start with "→" or not
#steps = answer.split("\n")
#for step in steps:
# if step.strip() and not step.strip().startswith("→"):
# return False

# check final answer in \boxed{} or not
if not re.search(r'\\boxed{.*}', answer):
return False
Expand Down
26 changes: 22 additions & 4 deletions dataflow/process/text/reasoners/answer_ground_truth_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataflow.core import TextFilter
from dataflow.core import ReasonerFilter
from math_verify import parse, verify, LatexExtractionConfig
import numpy as np
from dataflow.utils.registry import PROCESSOR_REGISTRY
#from math_verify import parse, verify, LatexExtractionConfig
Expand Down Expand Up @@ -214,18 +215,35 @@ def _get_last_number_answer(self, pred_str, use_last_number):


@PROCESSOR_REGISTRY.register()
class AnswerGroundTruthFilter(TextFilter):
class AnswerGroundTruthFilter(ReasonerFilter):
def __init__(self, args_dict: dict):
super().__init__(args_dict)
self.filter_name = 'AnswerGroundTruthFilter'
unit_manager = UnitTextManager()
string_cleaner = StringCleaner(unit_manager)
self.answer_extractor = AnswerExtractor(string_cleaner)

name2compare = {
'exact': self.exact_compare,
'math_verify': self.math_verify_compare
}

self.compare = name2compare[args_dict.get('compare_method', 'exact')]

def exact_compare(self, answer, ground_truth):
return answer == ground_truth

def math_verify_compare(self, answer, ground_truth):
try:
return verify(parse(ground_truth), parse(answer))
except:
return False

def filter_func(self, dataset):
indexes = np.zeros(len(dataset)).astype(int)
for i in range(len(dataset)):
final_answer = self.answer_extractor.extract_answer(dataset[i]['answer'], dataset[i].get('data_name', None))
if 'ground_truth_answer'in dataset[i] and final_answer == dataset[i]['ground_truth_answer']:
indexes[i] = 1
if 'ground_truth_answer' in dataset[i]:
if self.compare(final_answer, dataset[i]['ground_truth_answer']):
indexes[i] = 1
return indexes
4 changes: 2 additions & 2 deletions dataflow/process/text/reasoners/answer_ngram_filter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from dataflow.core import TextFilter
from dataflow.core import ReasonerFilter
import numpy as np
import re
from dataflow.utils.registry import PROCESSOR_REGISTRY
from dataflow.Eval.Text import NgramScorer

@PROCESSOR_REGISTRY.register()
class AnswerNgramFilter(TextFilter):
class AnswerNgramFilter(ReasonerFilter):
def __init__(self, args_dict: dict):
super().__init__(args_dict)
self.filter_name = 'AnswerNgramFilter'
Expand Down
5 changes: 2 additions & 3 deletions dataflow/process/text/reasoners/answer_token_length_filter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataflow.core import TextFilter
from dataflow.core import ReasonerFilter
import numpy as np
from dataflow.utils.registry import PROCESSOR_REGISTRY
from transformers import AutoTokenizer

@PROCESSOR_REGISTRY.register()
class AnswerTokenLengthFilter(TextFilter):
class AnswerTokenLengthFilter(ReasonerFilter):
def __init__(self, args_dict: dict):
super().__init__(args_dict)
self.filter_name = 'AnswerTokenLengthFilter'
Expand All @@ -13,7 +13,6 @@ def __init__(self, args_dict: dict):

def filter_func(self, dataset):
def get_token_count(input_string):
# 使用 tokenizer 对字符串进行编码,并获取 token 数目
tokens = self.tokenizer.encode(input_string, add_special_tokens=False)
return len(tokens)

Expand Down
Loading