Skip to content

Commit 86d7a94

Browse files
Update the way of saving data. (#63)
* feat: delete good / bad info list, change the way of writing data.
* feat: delete interval_size in InputArgs.py
* feat: fix raising error when not save data.
* feat: add write test.
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent aa45b7a commit 86d7a94

File tree

6 files changed

+75
-108
lines changed

6 files changed

+75
-108
lines changed

dingo/exec/local.py

Lines changed: 41 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ def __init__(self, input_args: InputArgs):
2727
self.input_args: InputArgs = input_args
2828
self.llm: Optional[BaseLLM] = None
2929
self.summary: SummaryModel = SummaryModel()
30-
self.bad_info_list: List[ResultInfo] = []
31-
self.good_info_list: List[ResultInfo] = []
3230

3331
def load_data(self) -> Generator[MetaData, None, None]:
3432
"""
@@ -68,19 +66,11 @@ def execute(self) -> List[SummaryModel]:
6866
eval_group=group_name,
6967
input_path=input_path,
7068
output_path=output_path if self.input_args.save_data else '',
71-
create_time=create_time,
72-
score=0,
73-
num_good=0,
74-
num_bad=0,
75-
total=0,
76-
type_ratio={},
77-
name_ratio={}
69+
create_time=create_time
7870
)
7971
self.evaluate()
8072
self.summary = self.summarize(self.summary)
81-
self.summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
82-
if self.input_args.save_data:
83-
self.save_data(output_path, self.input_args, self.bad_info_list, self.good_info_list, self.summary)
73+
self.write_summary(self.summary.output_path, self.input_args, self.summary)
8474

8575
return [self.summary]
8676

@@ -98,8 +88,6 @@ def evaluate(self):
9888
pbar = tqdm(total=None, unit='items')
9989

10090
def process_batch(batch: List):
101-
save_flag = False
102-
10391
futures=[]
10492
for group_type, group in Model.get_group(self.input_args.eval_group).items():
10593
if group_type == 'rule':
@@ -111,46 +99,19 @@ def process_batch(batch: List):
11199

112100
for future in concurrent.futures.as_completed(futures):
113101
result_info = future.result()
114-
# calculate summary ratio
102+
for t in result_info.type_list:
103+
self.summary.type_ratio[t] += 1
104+
for n in result_info.name_list:
105+
self.summary.name_ratio[n] += 1
115106
if result_info.error_status:
116-
self.bad_info_list.append(result_info)
117107
self.summary.num_bad += 1
118-
for t in result_info.type_list:
119-
if t not in self.summary.type_ratio:
120-
self.summary.type_ratio[t] = 1
121-
else:
122-
self.summary.type_ratio[t] += 1
123-
for n in result_info.name_list:
124-
if n not in self.summary.name_ratio:
125-
self.summary.name_ratio[n] = 1
126-
else:
127-
self.summary.name_ratio[n] += 1
128108
else:
129-
if self.input_args.save_correct:
130-
self.good_info_list.append(result_info)
131-
for t in result_info.type_list:
132-
if t not in self.summary.type_ratio:
133-
self.summary.type_ratio[t] = 1
134-
else:
135-
self.summary.type_ratio[t] += 1
136-
for n in result_info.name_list:
137-
if n not in self.summary.name_ratio:
138-
self.summary.name_ratio[n] = 1
139-
else:
140-
self.summary.name_ratio[n] += 1
109+
self.summary.num_good += 1
141110
self.summary.total += 1
142-
if self.summary.total % self.input_args.interval_size == 0:
143-
save_flag = True
111+
112+
self.write_single_data(self.summary.output_path, self.input_args, result_info)
144113
pbar.update()
145-
# save data in file
146-
if self.input_args.save_data:
147-
if save_flag:
148-
tmp_summary = self.summarize(self.summary)
149-
tmp_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
150-
tmp_output_path = self.summary.output_path
151-
self.save_data(tmp_output_path, self.input_args, self.bad_info_list, self.good_info_list, tmp_summary)
152-
self.bad_info_list = []
153-
self.good_info_list = []
114+
self.write_summary(self.summary.output_path, self.input_args, self.summarize(self.summary))
154115
while True:
155116
batch = list(itertools.islice(data_iter, self.input_args.batch_size))
156117
if not batch:
@@ -270,9 +231,9 @@ def evaluate_prompt(self, group: List[BasePrompt], d: MetaData) -> ResultInfo:
270231

271232
def summarize(self, summary: SummaryModel) -> SummaryModel:
272233
new_summary = copy.deepcopy(summary)
234+
new_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
273235
if new_summary.total == 0:
274236
return new_summary
275-
new_summary.num_good = new_summary.total - new_summary.num_bad
276237
new_summary.score = round(new_summary.num_good / new_summary.total * 100, 2)
277238
for t in new_summary.type_ratio:
278239
new_summary.type_ratio[t] = round(new_summary.type_ratio[t] / new_summary.total, 6)
@@ -282,52 +243,38 @@ def summarize(self, summary: SummaryModel) -> SummaryModel:
282243
new_summary.name_ratio = dict(sorted(new_summary.name_ratio.items()))
283244
return new_summary
284245

285-
def get_summary(self):
286-
return self.summary
246+
def write_single_data(self, path: str, input_args: InputArgs, result_info: ResultInfo):
247+
if not input_args.save_data:
248+
return
287249

288-
def get_bad_info_list(self):
289-
return self.bad_info_list
290-
291-
def get_good_info_list(self):
292-
return self.good_info_list
250+
if not input_args.save_correct and not result_info.error_status:
251+
return
293252

294-
def save_data(
295-
self,
296-
path: str,
297-
input_args: InputArgs,
298-
bad_info_list: List[ResultInfo],
299-
good_info_list: List[ResultInfo],
300-
summary: SummaryModel,
301-
):
302-
for result_info in bad_info_list:
303-
for new_name in result_info.name_list:
304-
t = str(new_name).split('-')[0]
305-
n = str(new_name).split('-')[1]
306-
p_t = os.path.join(path, t)
307-
if not os.path.exists(p_t):
308-
os.makedirs(p_t)
309-
f_n = os.path.join(path, t, n) + ".jsonl"
310-
with open(f_n, 'a', encoding='utf-8') as f:
311-
if input_args.save_raw:
312-
str_json = json.dumps(result_info.to_raw_dict(), ensure_ascii=False)
313-
else:
314-
str_json = json.dumps(result_info.to_dict(), ensure_ascii=False)
315-
f.write(str_json + '\n')
316-
if input_args.save_correct:
317-
for result_info in good_info_list:
318-
for new_name in result_info.name_list:
319-
t = str(new_name).split('-')[0]
320-
n = str(new_name).split('-')[1]
321-
p_t = os.path.join(path, t)
322-
if not os.path.exists(p_t):
323-
os.makedirs(p_t)
324-
f_n = os.path.join(path, t, n) + ".jsonl"
325-
with open(f_n, 'a', encoding='utf-8') as f:
326-
if input_args.save_raw:
327-
str_json = json.dumps(result_info.to_raw_dict(), ensure_ascii=False)
328-
else:
329-
str_json = json.dumps(result_info.to_dict(), ensure_ascii=False)
330-
f.write(str_json + '\n')
253+
for new_name in result_info.name_list:
254+
t = str(new_name).split('-')[0]
255+
n = str(new_name).split('-')[1]
256+
p_t = os.path.join(path, t)
257+
if not os.path.exists(p_t):
258+
os.makedirs(p_t)
259+
f_n = os.path.join(path, t, n) + ".jsonl"
260+
with open(f_n, 'a', encoding='utf-8') as f:
261+
if input_args.save_raw:
262+
str_json = json.dumps(result_info.to_raw_dict(), ensure_ascii=False)
263+
else:
264+
str_json = json.dumps(result_info.to_dict(), ensure_ascii=False)
265+
f.write(str_json + '\n')
331266

267+
def write_summary(self, path: str, input_args: InputArgs, summary: SummaryModel):
268+
if not input_args.save_data:
269+
return
332270
with open(path + '/summary.json', 'w', encoding='utf-8') as f:
333271
json.dump(summary.to_dict(), f, indent=4, ensure_ascii=False)
272+
273+
def get_summary(self):
274+
pass
275+
276+
def get_bad_info_list(self):
277+
pass
278+
279+
def get_good_info_list(self):
280+
pass

dingo/io/input/InputArgs.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class InputArgs(BaseModel):
2424
# Resume settings
2525
start_index: int = 0
2626
end_index: int = -1
27-
interval_size: int = 1000
2827

2928
# Concurrent settings
3029
max_workers: int = 1
@@ -89,10 +88,6 @@ def check_args(self):
8988
if self.end_index >= 0 and self.end_index < self.start_index:
9089
raise ValueError("if end_index is non negative, end_index must be greater than start_index")
9190

92-
# check interval size
93-
if self.interval_size <= 0:
94-
raise ValueError("interval_size must be positive.")
95-
9691
# check max workers
9792
if self.max_workers <= 0:
9893
raise ValueError("max_workers must be a positive integer.")

dingo/io/output/SummaryModel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections import defaultdict
12
from typing import Dict, List
23

3-
from pydantic import BaseModel
4+
from pydantic import BaseModel, Field
45

56

67
class SummaryModel(BaseModel):
@@ -15,8 +16,8 @@ class SummaryModel(BaseModel):
1516
num_good: int = 0
1617
num_bad: int = 0
1718
total: int = 0
18-
type_ratio: Dict[str, float] = {}
19-
name_ratio: Dict[str, float] = {}
19+
type_ratio: Dict[str, int] = Field(default_factory=lambda: defaultdict(int))
20+
name_ratio: Dict[str, int] = Field(default_factory=lambda: defaultdict(int))
2021

2122
def to_dict(self):
2223
return {

dingo/run/cli.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ def parse_args():
3030
default=None, help="The number of data start to check.")
3131
parser.add_argument("--end_index", type=int,
3232
default=None, help="The number of data end to check.")
33-
parser.add_argument("--interval_size", type=int,
34-
default=None, help="The number of size to save while checking.")
3533
parser.add_argument("--max_workers", type=int,
3634
default=None, help="The number of max workers to concurrent check. ")
3735
parser.add_argument("--batch_size", type=int,
@@ -112,8 +110,6 @@ def parse_args():
112110
input_data['start_index'] = args.start_index
113111
if args.end_index:
114112
input_data['end_index'] = args.end_index
115-
if args.interval_size:
116-
input_data['interval_size'] = args.interval_size
117113
if args.max_workers:
118114
input_data['max_workers'] = args.max_workers
119115
if args.batch_size:

docs/config.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
| --save_raw | bool | False | No | whether save raw data. |
1818
| --start_index | int | 0 | No | the index of the data at which checking starts. |
1919
| --end_index | int | -1 | No | the index of the data at which checking ends. If it is negative, all data from start_index to the end is included. |
20-
| --interval_size | int | 1000 | No | the number of size to save while checking. |
2120
| --max_workers | int | 1 | No | the maximum number of workers used for concurrent checking. |
2221
| --batch_size | int | 1 | No | the maximum number of data items per batch for concurrent checking. |
2322
| --dataset | str | "hugging_face" | Yes | dataset type, in ['hugging_face', 'local'] |
@@ -46,7 +45,6 @@
4645
| save_raw | bool | False | No | whether save raw data. |
4746
| start_index | int | 0 | No | the index of the data at which checking starts. |
4847
| end_index | int | -1 | No | the index of the data at which checking ends. If it is negative, all data from start_index to the end is included. |
49-
| interval_size | int | 1000 | No | the number of size to save while checking. |
5048
| max_workers | int | 1 | No | the maximum number of workers used for concurrent checking. |
5149
| batch_size | int | 1 | No | the maximum number of data items per batch for concurrent checking. |
5250
| dataset | str | "hugging_face" | Yes | dataset type, in ['hugging_face', 'local'] |

test/scripts/test_write.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
import shutil
3+
4+
import pytest
5+
from dingo.exec import Executor
6+
from dingo.io import InputArgs
7+
8+
9+
class TestWrite:
10+
def test_write_local_jsonl(self):
11+
input_args = InputArgs(**{
12+
"eval_group": "qa_standard_v1",
13+
"input_path": "../data/test_local_jsonl.jsonl",
14+
"save_data": True,
15+
"save_correct": True,
16+
"dataset": "local",
17+
"data_format": "jsonl",
18+
"column_id": "id",
19+
"column_content": "content",
20+
})
21+
executor = Executor.exec_map["local"](input_args)
22+
result = executor.execute()
23+
# print(result)
24+
output_path = result[0].output_path
25+
assert os.path.exists(output_path)
26+
shutil.rmtree('outputs')
27+
28+
29+
if __name__ == '__main__':
30+
pytest.main(["-s", "-q"])

0 commit comments

Comments
 (0)