Skip to content

Commit 5cdc4a7

Browse files
authored
Merge pull request #14 from TheRoadQaQ/main
Reasoner process, Answer method
2 parents 88d8bed + 50531a3 commit 5cdc4a7

File tree

7 files changed

+42
-35
lines changed

7 files changed

+42
-35
lines changed

configs/process/text_process_reasoner_ansfilter.yaml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
model_cache_path: '../ckpt' # Path to cache models
22
dependencies: [text]
3-
save_path: "./processed.jsonl"
3+
save_path: "../dataflow-develop/processed.jsonl"
44

55
data:
66
text:
@@ -9,21 +9,20 @@ data:
99
dataset_split: 'train'
1010
name: 'default'
1111
revision: null
12-
data_path: 'demos/text_process/reasoners/math_5_samples.json' # Local data path, supports json, jsonl, parquet formats
12+
data_path: './demos/text_process/reasoners/math_5_samples.json' # Local data path, supports json, jsonl, parquet formats
1313
formatter: "TextFormatter" # Data loader type
1414
keys: 'answer' # Key name to be processed, for sft data, it can be specified as ['instruction','input','output']
1515

1616
processors:
17-
AnswerFormatterFilter:
18-
type: "default"
17+
AnswerFormatterFilter: {}
1918
AnswerNgramFilter:
20-
min_score: 0.5
19+
min_score: 0.1
2120
max_score: 1.0
2221
ngrams: 5
2322
AnswerGroundTruthFilter:
24-
compare_method: exact # exact/math_verify/xverify
23+
compare_method: math_verify # exact or math_verify
2524
AnswerTokenLengthFilter:
26-
max_answer_token_length: 1024
25+
max_answer_token_length: 512
2726
tokenizer_dir: '../Qwen2.5-0.5B-Instruct'
2827

2928

dataflow/core/process/reasoner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ def __init__(self, args=None):
1919
self.filter_name = "ReasonerFilter"
2020
self.args = args
2121

22-
api_args = args['api_args']
23-
self.model_name = api_args['model_name']
24-
self.api_url = api_args['api_url']
25-
self.mode_test = api_args['mode_test']
22+
api_args = args.get('api_args', None)
23+
if api_args is not None:
24+
self.model_name = api_args['model_name']
25+
self.api_url = api_args['api_url']
26+
self.mode_test = api_args['mode_test']
2627
def filter_func(self, dataset):
2728
pass
2829

dataflow/process/text/reasoners/answer_formatter_filter.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
1-
from dataflow.core import TextFilter
1+
from dataflow.core import TextFilter, ReasonerFilter
22
import numpy as np
33
from dataflow.utils.registry import PROCESSOR_REGISTRY
44
import re
55

66
@PROCESSOR_REGISTRY.register()
7-
class AnswerFormatterFilter(TextFilter):
7+
class AnswerFormatterFilter(ReasonerFilter):
88
def __init__(self, args_dict: dict):
99
super().__init__(args_dict)
1010
self.filter_name = 'AnswerFormatterFilter'
1111

1212
def is_valid_answer(answer: str) -> bool:
13-
# start with "Solution:"
14-
if not answer.startswith("Solution:"):
15-
return False
16-
17-
# check that every step start with "→" or not
18-
#steps = answer.split("\n")
19-
#for step in steps:
20-
# if step.strip() and not step.strip().startswith("→"):
21-
# return False
22-
2313
# check final answer in \boxed{} or not
2414
if not re.search(r'\\boxed{.*}', answer):
2515
return False

dataflow/process/text/reasoners/answer_ground_truth_filter.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from dataflow.core import TextFilter
1+
from dataflow.core import ReasonerFilter
2+
from math_verify import parse, verify, LatexExtractionConfig
23
import numpy as np
34
from dataflow.utils.registry import PROCESSOR_REGISTRY
45
#from math_verify import parse, verify, LatexExtractionConfig
@@ -214,18 +215,35 @@ def _get_last_number_answer(self, pred_str, use_last_number):
214215

215216

216217
@PROCESSOR_REGISTRY.register()
217-
class AnswerGroundTruthFilter(TextFilter):
218+
class AnswerGroundTruthFilter(ReasonerFilter):
218219
def __init__(self, args_dict: dict):
219220
super().__init__(args_dict)
220221
self.filter_name = 'AnswerGroundTruthFilter'
221222
unit_manager = UnitTextManager()
222223
string_cleaner = StringCleaner(unit_manager)
223224
self.answer_extractor = AnswerExtractor(string_cleaner)
224225

226+
name2compare = {
227+
'exact': self.exact_compare,
228+
'math_verify': self.math_verify_compare
229+
}
230+
231+
self.compare = name2compare[args_dict.get('compare_method', 'exact')]
232+
233+
def exact_compare(self, answer, ground_truth):
234+
return answer == ground_truth
235+
236+
def math_verify_compare(self, answer, ground_truth):
237+
try:
238+
return verify(parse(ground_truth), parse(answer))
239+
except:
240+
return False
241+
225242
def filter_func(self, dataset):
226243
indexes = np.zeros(len(dataset)).astype(int)
227244
for i in range(len(dataset)):
228245
final_answer = self.answer_extractor.extract_answer(dataset[i]['answer'], dataset[i].get('data_name', None))
229-
if 'ground_truth_answer'in dataset[i] and final_answer == dataset[i]['ground_truth_answer']:
230-
indexes[i] = 1
246+
if 'ground_truth_answer' in dataset[i]:
247+
if self.compare(final_answer, dataset[i]['ground_truth_answer']):
248+
indexes[i] = 1
231249
return indexes

dataflow/process/text/reasoners/answer_ngram_filter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from dataflow.core import TextFilter
1+
from dataflow.core import ReasonerFilter
22
import numpy as np
33
import re
44
from dataflow.utils.registry import PROCESSOR_REGISTRY
55
from dataflow.Eval.Text import NgramScorer
66

77
@PROCESSOR_REGISTRY.register()
8-
class AnswerNgramFilter(TextFilter):
8+
class AnswerNgramFilter(ReasonerFilter):
99
def __init__(self, args_dict: dict):
1010
super().__init__(args_dict)
1111
self.filter_name = 'AnswerNgramFilter'

dataflow/process/text/reasoners/answer_token_length_filter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from dataflow.core import TextFilter
1+
from dataflow.core import ReasonerFilter
22
import numpy as np
33
from dataflow.utils.registry import PROCESSOR_REGISTRY
44
from transformers import AutoTokenizer
55

66
@PROCESSOR_REGISTRY.register()
7-
class AnswerTokenLengthFilter(TextFilter):
7+
class AnswerTokenLengthFilter(ReasonerFilter):
88
def __init__(self, args_dict: dict):
99
super().__init__(args_dict)
1010
self.filter_name = 'AnswerTokenLengthFilter'
@@ -13,7 +13,6 @@ def __init__(self, args_dict: dict):
1313

1414
def filter_func(self, dataset):
1515
def get_token_count(input_string):
16-
# 使用 tokenizer 对字符串进行编码,并获取 token 数目
1716
tokens = self.tokenizer.encode(input_string, add_special_tokens=False)
1817
return len(tokens)
1918

demos/text_process/reasoners/math_5_samples.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
},
88
{
99
"answer":
10-
"1. Find critical points:\n → f'(x) = 3x² - 6x\n → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2\n \n 2. Evaluate function at critical points and endpoints:\n → f(-1) = (-1)^3 - 3(-1)^2 + 4 = -1 -3 +4 = 0.0000\n → f(0) = 0³ - 3(0)² +4 = 4.0000\n → f(2) = 8 - 12 +4 = 0.0000\n → f(3) = 27 - 27 +4 = 4.0000\n \n 3. Compare values:\n → Minimum occurs at x=-1 and x=2\n \n Verification:\n → Second derivative test: f''(x) = 6x-6\n → f''(-1) = -12 < 0 (local max)\n → f''(2) = 6 > 0 (local min)\n \n \\boxed{0}"
10+
"Solution:\n 1. Find critical points:\n → f'(x) = 3x² - 6x\n → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2\n \n 2. Evaluate function at critical points and endpoints:\n → f(-1) = (-1)^3 - 3(-1)^2 + 4 = -1 -3 +4 = 0.0000\n → f(0) = 0³ - 3(0)² +4 = 4.0000\n → f(2) = 8 - 12 +4 = 0.0000\n → f(3) = 27 - 27 +4 = 4.0000\n \n 3. Compare values:\n → Minimum occurs at x=-1 and x=2\n \n Verification:\n → Second derivative test: f''(x) = 6x-6\n → f''(-1) = -12 < 0 (local max)\n → f''(2) = 6 > 0 (local min)\n \n \\boxed{0.5}"
1111
,
12-
"ground_truth_answer": "0"
12+
"ground_truth_answer": "1/2"
1313
},
1414
{
1515
"answer":
@@ -30,7 +30,7 @@
3030
},
3131
{
3232
"answer":
33-
"Solution:\n 1. Find critical points:\n → f'(x) = 3x² - 6x\n 1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2\n \n 2. Evaluate function at critical points and endpoints:\n → f(-1) = (-1)^3 - 3(-1)^2 + 4 = -1 -3 +4 = 0.0000\n → f(0) = 0³ - 3(0)² +4 = 4.0000\n → f(2) = 8 - 12 +4 = 0.0000\n → f(3) = 27 - 27 +4 = 4.0000\n \n 3. Compare values:\n → Minimum occurs at x=-1 and x=2\n \n Verification:\n → Second derivative test: f''(x) = 6x-6\n → f''(-1) = -12 < 0 (local max)\n → f''(2) = 6 > 0 (local min)\n \n \\boxed{0}"
33+
"Solution:\n 1. Find critical points:\n → f'(x) = 3x² - 6x\n 1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n1. Find critical points:\n → f'(x) = 3x² - 6x\n → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2 → Set derivative to zero: 3x(x-2) = 0 ⇒ x=0, x=2\n \n 2. Evaluate function at critical points and endpoints:\n → f(-1) = (-1)^3 - 3(-1)^2 + 4 = -1 -3 +4 = 0.0000\n → f(0) = 0³ - 3(0)² +4 = 4.0000\n → f(2) = 8 - 12 +4 = 0.0000\n → f(3) = 27 - 27 +4 = 4.0000\n \n 3. Compare values:\n → Minimum occurs at x=-1 and x=2\n \n Verification:\n → Second derivative test: f''(x) = 6x-6\n → f''(-1) = -12 < 0 (local max)\n → f''(2) = 6 > 0 (local min)\n \n \\boxed{0}"
3434
,
3535
"ground_truth_answer": "0"
3636
}

0 commit comments

Comments
 (0)