1+ from dataflow .data import DataFlowDataset
2+ from dataflow .core import ScoreRecord
3+ from datasets import Dataset
4+
5+ class Reasoner ():
6+ def __init__ (self , args = None ):
7+ pass
8+
9+ def reason_func (self , dataset ):
10+ pass
11+
12+ def __call__ (self , dataset : DataFlowDataset ):
13+ pass
14+
15+ class ReasonerFilter (Reasoner ):
16+ def __init__ (self , args = None ):
17+ super ().__init__ ()
18+ self .data_type = "text"
19+ self .filter_name = "ReasonerFilter"
20+ self .args = args
21+
22+ api_args = args ['api_args' ]
23+ self .model_name = api_args ['model_name' ]
24+ self .api_url = api_args ['api_url' ]
25+ self .mode_test = api_args ['mode_test' ]
26+ def filter_func (self , dataset ):
27+ pass
28+
29+ def __call__ (self , dataset : DataFlowDataset ):
30+ """Processes the dataset using the reasoner"""
31+ init_len = len (dataset )
32+ score_record = ScoreRecord ()
33+ dataset .set_score_record (score_record )
34+ labels = self .filter_func (dataset )
35+
36+ if isinstance (dataset .dataset , Dataset ):
37+ def filter_by_labels (example , index ):
38+ return labels [index ] == 1
39+ dataset .dataset = dataset .dataset .filter (filter_by_labels , with_indices = True )
40+ filtered_dataset = dataset
41+ else :
42+ filtered_dataset = dataset .filter (labels )
43+
44+ print (f'Implemented { self .filter_name } . Data Number: { init_len } -> { len (filtered_dataset )} ' , flush = True )
45+ return filtered_dataset
0 commit comments