Skip to content

Commit 88d8bed

Browse files
authored
Merge pull request #12 from scuuy/main
Add the reasoners module under process, register it where necessary, and add some demos.
2 parents 8b5d570 + f7b6093 commit 88d8bed

19 files changed

+268
-18
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ configs/process/experiments/
2222
**/data
2323
**/ckpt
2424
tmp.*
25+
configs/process/text_process_reasoner.yaml

configs/process/text_process.yaml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,4 +299,12 @@ processors:
299299
device: 'cuda:0'
300300
model_name: 'hkust-nlp/deita-quality-scorer'
301301
max_length: 512
302-
302+
MathProblemFilter:
303+
min_score: 0
304+
max_score: 1
305+
api_args:
306+
api_url: 'The URL of api, default using general method'
307+
api_key: 'Your Key'
308+
model_name: 'gpt-4o'
309+
mode_test: True
310+

configs/process/text_process_reasoning.yaml renamed to configs/process/text_process_reasoner_ansfilter.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ data:
99
dataset_split: 'train'
1010
name: 'default'
1111
revision: null
12-
data_path: 'demos/reasoning_process/math_5_samples.json' # Local data path, supports json, jsonl, parquet formats
12+
data_path: 'demos/text_process/reasoners/math_5_samples.json' # Local data path, supports json, jsonl, parquet formats
1313
formatter: "TextFormatter" # Data loader type
1414
keys: 'answer' # Key name to be processed, for sft data, it can be specified as ['instruction','input','output']
1515

dataflow/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .process.filter import Filter, ImageFilter, ImageTextFilter, VideoFilter, TextFilter, VideoTextFilter
33
from .process.refiner import Refiner, TextRefiner
44
from .process.deduplicator import Deduplicator, TextDeduplicator, ImageDeduplicator
5+
from .process.reasoner import ReasonerFilter
56

67
__all__ = [
78
'Scorer',
@@ -20,5 +21,6 @@
2021
'Refiner',
2122
'TextRefiner',
2223
'Deduplicator',
23-
'TextDeduplicator'
24+
'TextDeduplicator',
25+
'ReasonerFilter'
2426
]

dataflow/core/process/reasoner.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from dataflow.data import DataFlowDataset
2+
from dataflow.core import ScoreRecord
3+
from datasets import Dataset
4+
5+
class Reasoner:
    """Interface stub for reasoners.

    Every method here is a no-op that returns ``None``; subclasses supply
    the actual behaviour.
    """

    def __init__(self, args=None):
        """Accept an optional ``args`` value; it is neither stored nor used here."""

    def reason_func(self, dataset):
        """Placeholder hook: does nothing with ``dataset`` and returns ``None``."""
        return None

    def __call__(self, dataset: DataFlowDataset):
        """Placeholder entry point: does nothing with ``dataset`` and returns ``None``."""
        return None
14+
15+
class ReasonerFilter(Reasoner):
    """Base class for label-based text filters in the reasoners module.

    ``filter_func`` is expected to produce one label per sample;
    ``__call__`` keeps the samples whose label equals ``1``.
    """

    def __init__(self, args=None):
        """Store the configuration and read the optional API settings.

        Args:
            args: Mapping of filter options, or ``None``. An optional
                ``'api_args'`` entry may provide ``'model_name'``,
                ``'api_url'`` and ``'mode_test'``.
        """
        super().__init__(args)
        self.data_type = "text"
        self.filter_name = "ReasonerFilter"
        self.args = args

        # ``args`` defaults to None and a config may omit the API section,
        # so fall back to defaults instead of failing on the subscript.
        api_args = (args or {}).get('api_args') or {}
        self.model_name = api_args.get('model_name')
        self.api_url = api_args.get('api_url')
        self.mode_test = api_args.get('mode_test', False)

    def filter_func(self, dataset):
        """Hook for the per-sample labels; this base version returns ``None``.

        ``__call__`` indexes the returned labels by sample position and
        keeps the entries equal to ``1``.
        """
        pass

    def __call__(self, dataset: DataFlowDataset):
        """Filter ``dataset`` with the labels produced by ``filter_func``.

        A fresh ``ScoreRecord`` is attached to ``dataset`` before the labels
        are computed. Returns the filtered dataset and prints the sample
        count before and after filtering.
        """
        init_len = len(dataset)
        score_record = ScoreRecord()
        dataset.set_score_record(score_record)
        labels = self.filter_func(dataset)

        if isinstance(dataset.dataset, Dataset):
            # The wrapped `datasets.Dataset` is filtered and reassigned on
            # the same wrapper object, which is then returned.
            def filter_by_labels(example, index):
                return labels[index] == 1
            dataset.dataset = dataset.dataset.filter(filter_by_labels, with_indices=True)
            filtered_dataset = dataset
        else:
            # Any other backing store: delegate to the wrapper's own filter().
            filtered_dataset = dataset.filter(labels)

        print(f'Implemented {self.filter_name}. Data Number: {init_len} -> {len(filtered_dataset)}', flush=True)
        return filtered_dataset

dataflow/process/text/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .filters import *
22
from .refiners import *
33
from .deduplicators import *
4-
from .reasoning import *
4+
from .reasoners import *
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import sys
2+
from dataflow.utils.registry import LazyLoader
3+
4+
_import_structure = {
5+
"MathProblemFilter": ("dataflow/process/text/reasoners/math_problem_filter.py", "MathProblemFilter"),
6+
"AnswerGroundTruthFilter": ("dataflow/process/text/reasoners/answer_ground_truth_filter.py", "AnswerGroundTruthFilter"),
7+
"AnswerFormatterFilter": ("dataflow/process/text/reasoners/answer_formatter_filter.py", "AnswerFormatterFilter"),
8+
"AnswerNgramFilter": ("dataflow/process/text/reasoners/answer_ngram_filter.py", "AnswerNgramFilter"),
9+
"AnswerTokenLengthFilter": ("dataflow/process/text/reasoners/answer_token_length_filter.py", "AnswerTokenLengthFilter"),
10+
}
11+
12+
sys.modules[__name__] = LazyLoader(__name__, "dataflow/process/text/reasoners", _import_structure)

dataflow/process/text/reasoning/answer_formatter_filter.py renamed to dataflow/process/text/reasoners/answer_formatter_filter.py

File renamed without changes.

dataflow/process/text/reasoning/answer_ground_truth_filter.py renamed to dataflow/process/text/reasoners/answer_ground_truth_filter.py

File renamed without changes.

dataflow/process/text/reasoning/answer_ngram_filter.py renamed to dataflow/process/text/reasoners/answer_ngram_filter.py

File renamed without changes.

0 commit comments

Comments
 (0)