Skip to content

Commit 8b5d570

Browse files
authored
Merge pull request #11 from TheRoadQaQ/main
reasoning_process
2 parents b56c39a + 1770d4b commit 8b5d570

File tree

10 files changed

+403
-2
lines changed

10 files changed

+403
-2
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
model_cache_path: '../ckpt'        # Path to cache models
dependencies: [text]
save_path: "./processed.jsonl"

data:
  text:
    use_hf: False                  # Whether to use huggingface_dataset; if used, ignore the local data path below
    dataset_name: 'yahma/alpaca-cleaned'
    dataset_split: 'train'
    name: 'default'
    revision: null
    data_path: 'demos/reasoning_process/math_5_samples.json'  # Local data path, supports json, jsonl, parquet formats
    formatter: "TextFormatter"     # Data loader type
    keys: 'answer'                 # Key name to be processed; for sft data it can be specified as ['instruction','input','output']

processors:
  AnswerFormatterFilter:
    type: "default"
  AnswerNgramFilter:
    min_score: 0.5
    max_score: 1.0
    ngrams: 5
  AnswerGroundTruthFilter:
    compare_method: exact          # exact/math_verify/xverify
  AnswerTokenLengthFilter:
    max_answer_token_length: 1024
    tokenizer_dir: '../Qwen2.5-0.5B-Instruct'
28+
29+
30+

dataflow/process/text/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .filters import *
22
from .refiners import *
3-
from .deduplicators import *
3+
from .deduplicators import *
4+
from .reasoning import *
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import sys

from dataflow.utils.registry import LazyLoader

# Map each public processor name to (module file path, attribute name) so that
# LazyLoader can defer the actual import until the name is first accessed.
_import_structure = {
    "AnswerGroundTruthFilter": ("dataflow/process/text/reasoning/answer_ground_truth_filter.py", "AnswerGroundTruthFilter"),
    "AnswerFormatterFilter": ("dataflow/process/text/reasoning/answer_formatter_filter.py", "AnswerFormatterFilter"),
    "AnswerNgramFilter": ("dataflow/process/text/reasoning/answer_ngram_filter.py", "AnswerNgramFilter"),
    "AnswerTokenLengthFilter": ("dataflow/process/text/reasoning/answer_token_length_filter.py", "AnswerTokenLengthFilter"),
}

# Replace this package's module object with a LazyLoader so attribute access
# (e.g. reasoning.AnswerNgramFilter) triggers the deferred import declared above.
sys.modules[__name__] = LazyLoader(__name__, "dataflow/process/text/reasoning", _import_structure)
12+
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from dataflow.core import TextFilter
2+
import numpy as np
3+
from dataflow.utils.registry import PROCESSOR_REGISTRY
4+
import re
5+
6+
@PROCESSOR_REGISTRY.register()
class AnswerFormatterFilter(TextFilter):
    """Keep only samples whose 'answer' field follows the expected solution format.

    A valid answer must start with "Solution:" and contain a final result
    wrapped in \\boxed{...}.
    """

    def __init__(self, args_dict: dict):
        super().__init__(args_dict)
        self.filter_name = 'AnswerFormatterFilter'

    @staticmethod
    def is_valid_answer(answer: str) -> bool:
        """Return True if *answer* matches the required solution format.

        Declared as a @staticmethod: it takes no instance state, and without the
        decorator an ``self.is_valid_answer(...)`` call would have broken.
        """
        # Must start with the "Solution:" marker.
        if not answer.startswith("Solution:"):
            return False
        # A per-step check (each step starting with "→") existed here but was
        # intentionally disabled; re-add it if step formatting becomes mandatory.
        # The final answer must appear inside \boxed{...}.
        if not re.search(r'\\boxed{.*}', answer):
            return False
        return True

    def filter_func(self, dataset):
        """Return a 0/1 int array marking which samples pass the format check.

        Each element of *dataset* is expected to be a mapping with an 'answer'
        key (see the pipeline config's ``keys`` setting).
        """
        indexes = np.zeros(len(dataset)).astype(int)
        for i, item in enumerate(dataset):
            if AnswerFormatterFilter.is_valid_answer(item['answer']):
                indexes[i] = 1
        return indexes
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from dataflow.core import TextFilter
2+
import numpy as np
3+
from dataflow.utils.registry import PROCESSOR_REGISTRY
4+
#from math_verify import parse, verify, LatexExtractionConfig
5+
import pandas as pd
6+
from tqdm import tqdm
7+
import logging
8+
import re
9+
from word2number import w2n
10+
11+
# Helper class for string processing
class StringProcessor:
    """
    Encapsulates string-normalization helpers for mathematical expressions.
    """

    @staticmethod
    def _fix_fracs(string):
        """
        Rewrite shorthand fractions so they are properly braced: \\frac12 -> \\frac{1}{2}.
        Already-braced fractions (\\frac{a}{b}) are left untouched.
        """
        substrs = string.split("\\frac")
        new_str = substrs[0]
        if len(substrs) > 1:
            for substr in substrs[1:]:
                new_str += "\\frac"
                if len(substr) > 0 and substr[0] == "{":
                    # Numerator already braced: keep the tail verbatim.
                    new_str += substr
                else:
                    if len(substr) >= 2:
                        a, b = substr[0], substr[1]
                        if b != "{":
                            new_str += f"{{{a}}}{{{b}}}{substr[2:]}" if len(substr) > 2 else f"{{{a}}}{{{b}}}"
                        else:
                            new_str += f"{{{a}}}{b}{substr[2:]}" if len(substr) > 2 else f"{{{a}}}{b}"
                    else:
                        # Malformed tail (fewer than two characters): give up, return input unchanged.
                        return string
        return new_str

    @staticmethod
    def _fix_a_slash_b(string):
        """
        Convert a plain integer division "a/b" into \\frac{a}{b}.
        Anything that is not exactly a simple a/b fraction is returned unchanged.
        """
        if len(string.split("/")) != 2:
            return string
        a, b = string.split("/")
        try:
            a, b = int(a) if "sqrt" not in a else a, int(b) if "sqrt" not in b else b
            assert string == f"{a}/{b}"
            return f"\\frac{{{a}}}{{{b}}}"
        # int() raises ValueError on non-integers; the assert guards round-trip
        # fidelity. The original bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt, which was never intended.
        except (ValueError, AssertionError):
            return string

    @staticmethod
    def _fix_sqrt(string):
        """
        Ensure square roots are properly braced: \\sqrt2 -> \\sqrt{2}.
        """
        return re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)

    @staticmethod
    def convert_word_number(text: str) -> str:
        """
        Convert an English number word (e.g. "seven") to its digit string.
        Non-number text is passed through unchanged (best-effort conversion).
        """
        try:
            return str(w2n.word_to_num(text))
        # w2n raises ValueError on unrecognized input; catch Exception (not a
        # bare `except:`) to stay permissive without masking interpreter exits.
        except Exception:
            return text
71+
72+
73+
# Unit-text manager
class UnitTextManager:
    """
    Manages a list of unit terms to be stripped from answer strings.
    The vocabulary targets noisy word-problem text — presumably scraped math
    datasets (NOTE(review): confirm intended corpus; includes OCR artifacts
    like "m sqaure" and "v â € ™" on purpose).
    """

    def __init__(self):
        """
        Initialize the unit texts and extend the list with naive plural forms.
        """
        self.unit_texts = [
            "east", "degree", "mph", "kmph", "ft", "m sqaure", "m east", "sq m", "deg", "mile", "q .", "monkey", "prime",
            "ratio", "profit of rs", "rd", "o", "gm", "p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way", "west",
            "a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a", "h", "no change", "men", "soldier", "pie", "bc",
            "excess", "st", "inches", "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise", "a . m", "th", "π r 2", "sq",
            "mark", "l", "toy", "coin", "sq . m", "gallon", "° f", "profit", "minw", "yr", "women", "feet", "am", "pm", "hr",
            "cu cm", "square", "v â € ™", "are", "rupee", "rounds", "cubic", "cc", "mtr", "s", "ohm", "number", "kmph", "day",
            "hour", "minute", "min", "second", "man", "woman", "sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare",
            "more", "sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month", "km", "m", "cm", "mm", "apple", "liter",
            "loss", "yard", "pure", "year", "increase", "decrease", "d", "less", "Surface", "litre", "pi sq m", "s .", "metre",
            "meter", "inch",
        ]
        self.unit_texts.extend([t + "s" for t in self.unit_texts])

    def clean_units(self, string: str):
        """
        Remove every unit term that appears as a standalone token in *string*.

        Each unit is matched only when bounded by non-word characters (or the
        string edges), and the bounding characters are preserved via \\1\\2.
        """
        for unit_text in self.unit_texts:
            # re.escape is required: several units contain regex metacharacters
            # (e.g. "rs .", "q ."); without it, "." matched ANY character.
            string = re.sub(r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string)
        return string
104+
105+
106+
# Main string-processing class
class StringCleaner:
    """
    Cleans and normalizes strings containing mathematical expressions.
    """

    def __init__(self, unit_manager: UnitTextManager):
        """
        Store the unit manager used to strip unit terms from strings.
        """
        self.unit_manager = unit_manager

    def strip_string(self, string, skip_unit=False):
        """
        Strip unwanted characters, LaTeX decorations, and (optionally) unit
        terms from *string*. Replacement order is significant and preserved.
        """
        s = str(string).strip().replace("\n", "").rstrip(".").replace("\\!", "")

        # Normalize array/bmatrix environments to pmatrix.
        s = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", s)
        s = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", s).replace("bmatrix", "pmatrix")

        # Canonicalize fraction/comparison commands and drop sizing braces.
        for old, new in (
            ("tfrac", "frac"), ("dfrac", "frac"),
            ("\\neq", "\\ne"), ("\\leq", "\\le"), ("\\geq", "\\ge"),
            ("\\left", ""), ("\\right", ""),
            ("\\{", "{"), ("\\}", "}"),
        ):
            s = s.replace(old, new)

        # Optionally strip unit terms (skipped for datasets that keep units).
        if not skip_unit:
            s = self.unit_manager.clean_units(s)

        # Remove degree markers, currency signs, and math delimiters.
        for old in ("^{\\circ}", "^\\circ", "\\$", "$", "\\(", "\\)"):
            s = s.replace(old, "")

        s = StringProcessor.convert_word_number(s)
        s = re.sub(r"\\text\{(.*?)\}", r"\1", s)

        # Drop variable-assignment / set-membership prefixes.
        for key in ("x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"):
            s = s.replace(key, "")

        s = s.replace("\\emptyset", "{}").replace("(-\\infty,\\infty)", "\\mathbb{R}")

        # Strip percent signs and normalize bare leading decimal points.
        for old, new in (("%", ""), (" .", " 0."), ("{.", "{0.")):
            s = s.replace(old, new)

        return s
143+
144+
145+
# Core answer-extraction logic
class AnswerExtractor:
    """
    Extracts the final answer from a model prediction string.
    """

    def __init__(self, string_cleaner: StringCleaner):
        """
        Store the StringCleaner used to normalize extracted answers.
        """
        self.string_cleaner = string_cleaner

    def extract_answer(self, pred_str, data_name, use_last_number=True):
        """
        Extract the final answer from *pred_str*, handling several formats:
        "final answer is $...$", \\boxed{...}, "The/the answer is ...", or
        (as a fallback) the last number in the string.
        """
        # Strip the stray Cyrillic token "ки" — NOTE(review): appears to be a
        # dataset artifact; confirm against the upstream data.
        pred_str = pred_str.replace("\u043a\u0438", "")

        if "final answer is $" in pred_str and "$. I hope" in pred_str:
            pred = pred_str.split("final answer is $", 1)[1].split("$. I hope", 1)[0].strip()
        elif "boxed" in pred_str:
            pred = self._extract_boxed_answer(pred_str)
        elif "he answer is" in pred_str:
            # "he answer is" matches both "The answer is" and "the answer is".
            pred = pred_str.split("he answer is")[-1].strip()
        else:
            pred = self._get_last_number_answer(pred_str, use_last_number)

        # Some datasets keep units in the ground truth; skip unit stripping there.
        pred = self.string_cleaner.strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
        return pred

    def _extract_boxed_answer(self, pred_str):
        """
        Extract the answer enclosed in 'boxed' notation (text after the last
        "boxed", up to the matching brace or the next '$').
        """
        ans = pred_str.split("boxed")[-1]
        if ans.startswith("{"):
            return self._extract_bracketed_answer(ans)
        else:
            return ans.split("$")[0].strip()

    def _extract_bracketed_answer(self, ans):
        """
        Return the contents of the brace group opened at ans[0], honoring
        nested braces via a depth counter.
        """
        stack = 1
        result = ""
        for c in ans[1:]:
            if c == "{":
                stack += 1
                result += c
            elif c == "}":
                stack -= 1
                if stack == 0:
                    break
                result += c
            else:
                result += c
        return result

    def _get_last_number_answer(self, pred_str, use_last_number):
        """
        Return the last number in the string when *use_last_number* is True,
        else "". Commas are stripped so "1,234" parses as one number.
        """
        if use_last_number:
            # Raw string: the original "-?\d*\.?\d+" contained invalid escape
            # sequences, which warn on modern Python.
            pattern = r"-?\d*\.?\d+"
            pred = re.findall(pattern, pred_str.replace(",", ""))
            return pred[-1] if pred else ""
        return ""
214+
215+
216+
@PROCESSOR_REGISTRY.register()
class AnswerGroundTruthFilter(TextFilter):
    """Keep samples whose extracted final answer equals the ground-truth answer exactly."""

    def __init__(self, args_dict: dict):
        super().__init__(args_dict)
        self.filter_name = 'AnswerGroundTruthFilter'
        # Build the extraction pipeline: unit stripping -> string cleaning -> extraction.
        self.answer_extractor = AnswerExtractor(StringCleaner(UnitTextManager()))

    def filter_func(self, dataset):
        """Return a 0/1 int array; 1 where the sample has a 'ground_truth_answer'
        and the answer extracted from 'answer' matches it exactly."""
        keep = np.zeros(len(dataset)).astype(int)
        for idx, sample in enumerate(dataset):
            extracted = self.answer_extractor.extract_answer(
                sample['answer'], sample.get('data_name', None)
            )
            if 'ground_truth_answer' in sample and extracted == sample['ground_truth_answer']:
                keep[idx] = 1
        return keep
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from dataflow.core import TextFilter
2+
import numpy as np
3+
import re
4+
from dataflow.utils.registry import PROCESSOR_REGISTRY
5+
from dataflow.Eval.Text import NgramScorer
6+
7+
@PROCESSOR_REGISTRY.register()
class AnswerNgramFilter(TextFilter):
    """Filter samples by the n-gram uniqueness ratio of their 'answer' text.

    Highly repetitive answers score low (many duplicate n-grams); samples are
    kept only when the score lies within [min_score, max_score].
    """

    def __init__(self, args_dict: dict):
        super().__init__(args_dict)
        self.filter_name = 'AnswerNgramFilter'
        self.min_score = args_dict['min_score']
        self.max_score = args_dict['max_score']
        self.ngrams = args_dict['ngrams']

    def filter_func(self, dataset):
        """Return a 0/1 int array; 1 where unique-ngrams/total-ngrams is in range."""
        n = self.ngrams
        scores = []
        for sample in dataset:
            # Lowercase and drop punctuation before tokenizing on whitespace.
            text = re.sub(r'[^\w\s]', '', sample['answer'].lower())
            tokens = text.split()
            grams = [' '.join(tokens[start:start + n]) for start in range(len(tokens) - (n - 1))]
            # Ratio of distinct n-grams to total; 0.0 when the answer is shorter than n tokens.
            scores.append(len(set(grams)) / len(grams) if grams else 0.0)
        return np.array([self.min_score <= s <= self.max_score for s in scores]).astype(int)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from dataflow.core import TextFilter
2+
import numpy as np
3+
from dataflow.utils.registry import PROCESSOR_REGISTRY
4+
from transformers import AutoTokenizer
5+
6+
@PROCESSOR_REGISTRY.register()
class AnswerTokenLengthFilter(TextFilter):
    """Filter out samples whose 'answer' exceeds a maximum token count."""

    def __init__(self, args_dict: dict):
        super().__init__(args_dict)
        self.filter_name = 'AnswerTokenLengthFilter'
        self.max_answer_token_length = args_dict['max_answer_token_length']
        # Tokenizer used to measure answer length, loaded from the configured directory.
        self.tokenizer = AutoTokenizer.from_pretrained(args_dict['tokenizer_dir'])

    def filter_func(self, dataset):
        """Return a 0/1 int array; 1 where the tokenized answer fits the limit."""
        limit = self.max_answer_token_length

        def token_count(text):
            # Encode without special tokens so only the answer's own tokens are counted.
            return len(self.tokenizer.encode(text, add_special_tokens=False))

        return np.array([token_count(item['answer']) <= limit for item in dataset]).astype(int)

dataflow/utils/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get(self, name):
7474
raise e
7575
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
7676
elif self._name == "processor":
77-
for x in ['text.refiners', 'text.filters', 'text.deduplicators', 'image.filters', 'image.deduplicators', 'video.filters']:
77+
for x in ['text.refiners', 'text.filters', 'text.deduplicators', 'text.reasoning','image.filters', 'image.deduplicators', 'video.filters']:
7878
# for x in ['image.filters', 'image.refiners']:
7979
module_path = "dataflow.process." + x
8080
try:

0 commit comments

Comments
 (0)