61 changes: 35 additions & 26 deletions dingo/exec/local.py
@@ -94,12 +94,22 @@ def evaluate(self):
             group (Any): _description_
             group_type (str): _description_
         """
-        with concurrent.futures.ProcessPoolExecutor(max_workers=self.input_args.max_workers) as executor:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=self.input_args.max_workers) as thread_executor, \
+                concurrent.futures.ProcessPoolExecutor(max_workers=self.input_args.max_workers) as process_executor:
             data_iter = self.load_data()
             data_iter = itertools.islice(data_iter, self.input_args.start_index, None)
+            pbar = tqdm(total=None, unit='items')

             def process_batch(batch: List):
-                futures = [executor.submit(self.evaluate_single_data, self.input_args.eval_group, data) for data in batch]
+                futures=[]
+                for group_type, group in Model.get_group(self.input_args.eval_group).items():
+                    if group_type == 'rule':
+                        futures += [process_executor.submit(self.evaluate_single_data, group_type, group, data) for data in batch]
+                    elif group_type == 'prompt':
+                        futures += [thread_executor.submit(self.evaluate_single_data, group_type, group, data) for data in batch]
+                    else:
+                        raise RuntimeError(f'Unsupported group type: {group_type}')

                 for future in concurrent.futures.as_completed(futures):
                     future.result()
                     if self.input_args.save_data:
@@ -116,18 +126,17 @@ def process_batch(batch: List):
                         tmp_good_info_list = self.good_info_list[self.good_info_index:len(self.good_info_list)]
                         self.good_info_index = len(self.good_info_list)
                         self.save_data(tmp_output_path, self.input_args, tmp_bad_info_list, tmp_good_info_list, tmp_summary)

-            with tqdm(total=None, unit='items') as pbar:
-                while True:
-                    batch = list(itertools.islice(data_iter, self.input_args.batch_size))
-                    if not batch:
-                        break
-                    process_batch(batch)
-                    pbar.update()
+            while True:
+                batch = list(itertools.islice(data_iter, self.input_args.batch_size))
+                if not batch:
+                    break
+                process_batch(batch)


         log.debug('[Summary]: ' + str(self.summary))

-    def evaluate_single_data(self, group_name, data: MetaData):
+    def evaluate_single_data(self, group_type, group, data: MetaData):
         result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
         if self.input_args.save_raw:
             result_info.raw_data = data.raw_data
@@ -137,22 +146,22 @@ def evaluate_single_data(self, group_name, data: MetaData):
         good_name_list = []
         bad_reason_list = []
         good_reason_list = []
-        for group_type, group in Model.get_group(group_name).items():
-            if group_type == 'rule':
-                r_i = self.evaluate_rule(group, data)
-            elif group_type == 'prompt':
-                r_i = self.evaluate_prompt(group, data)
-            else:
-                raise RuntimeError(f'Unsupported group type: {group_type}')
-            if r_i.error_status:
-                result_info.error_status = True
-                bad_type_list = bad_type_list + r_i.type_list
-                bad_name_list = bad_name_list + r_i.name_list
-                bad_reason_list = bad_reason_list + r_i.reason_list
-            else:
-                good_type_list = good_type_list + r_i.type_list
-                good_name_list = good_name_list + r_i.name_list
-                good_reason_list = good_reason_list + r_i.reason_list
+        # for group_type, group in Model.get_group(group_name).items():
+        if group_type == 'rule':
+            r_i = self.evaluate_rule(group, data)
+        elif group_type == 'prompt':
+            r_i = self.evaluate_prompt(group, data)
+        else:
+            raise RuntimeError(f'Unsupported group type: {group_type}')
+        if r_i.error_status:
+            result_info.error_status = True
+            bad_type_list = bad_type_list + r_i.type_list
+            bad_name_list = bad_name_list + r_i.name_list
+            bad_reason_list = bad_reason_list + r_i.reason_list
+        else:
+            good_type_list = good_type_list + r_i.type_list
+            good_name_list = good_name_list + r_i.name_list
+            good_reason_list = good_reason_list + r_i.reason_list
         if result_info.error_status:
             result_info.type_list = list(set(bad_type_list))
             for name in bad_name_list:
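For readers skimming the diff, here is a minimal, self-contained sketch of the dispatch pattern it introduces: CPU-bound "rule" evaluation is submitted to a ProcessPoolExecutor, I/O-bound "prompt" (LLM) evaluation to a ThreadPoolExecutor, and both sets of futures are drained together with concurrent.futures.as_completed. The check_rule and call_prompt helpers below are hypothetical stand-ins, not dingo functions.

import concurrent.futures
import time


def check_rule(text: str) -> str:
    # Stand-in for a CPU-bound rule check; separate processes avoid GIL contention.
    return f"rule ok ({sum(ord(c) for c in text)})"


def call_prompt(text: str) -> str:
    # Stand-in for an I/O-bound LLM call; threads are enough while waiting on I/O.
    time.sleep(0.1)
    return f"prompt ok ({text})"


def process_batch(batch, thread_executor, process_executor):
    futures = []
    # Route CPU-bound work to processes and I/O-bound work to threads,
    # mirroring the rule/prompt split in the diff above.
    futures += [process_executor.submit(check_rule, item) for item in batch]
    futures += [thread_executor.submit(call_prompt, item) for item in batch]
    # as_completed accepts futures from both pools and yields each as it finishes.
    for future in concurrent.futures.as_completed(futures):
        print(future.result())


if __name__ == "__main__":
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as thread_executor, \
            concurrent.futures.ProcessPoolExecutor(max_workers=4) as process_executor:
        process_batch(["a", "bb", "ccc"], thread_executor, process_executor)

One constraint this pattern implies: callables and arguments submitted to the process pool must be picklable, so closures and locally defined wrappers cannot be used there.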
5 changes: 1 addition & 4 deletions dingo/model/model.py
@@ -184,11 +184,8 @@ def decorator(root_class):
             cls.rule_metric_type_map[metric_type].append(root_class)
             root_class.metric_type = metric_type

-            @wraps(root_class)
-            def wrapped_function(*args, **kwargs):
-                return root_class(*args, **kwargs)

-            return wrapped_function
+            return root_class

         return decorator

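For illustration, a toy registry (not dingo's actual Model class) contrasting the two decorator styles above. Returning the class itself keeps the registered name bound to a real module-level class, so it stays picklable and keeps classmethods and isinstance checks intact; whether that was the motivation here is an assumption, but it fits the ProcessPoolExecutor usage in dingo/exec/local.py.

registry = {}


def register(name):
    """Toy class-registration decorator (illustrative only)."""
    def decorator(root_class):
        registry[name] = root_class

        # Old style: replace the class with a factory closure. The decorated
        # name would then refer to a nested function, which cannot be pickled
        # and hides the original class.
        # @wraps(root_class)  # from functools
        # def wrapped_function(*args, **kwargs):
        #     return root_class(*args, **kwargs)
        # return wrapped_function

        # New style: register the class and hand it back unchanged.
        return root_class

    return decorator


@register('demo')
class DemoRule:
    @classmethod
    def eval(cls, text: str) -> bool:
        return bool(text.strip())


assert DemoRule is registry['demo']      # still the real class, not a wrapper
assert DemoRule.eval('hello') is True    # classmethods remain directly accessible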