Skip to content

Commit 0cc8107

Browse files
authored
feat: fix base functions of ExecProto and update summarize of spark (#64)
1 parent 86d7a94 commit 0cc8107

File tree

3 files changed

+67
-68
lines changed

3 files changed

+67
-68
lines changed

dingo/exec/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88

99
class ExecProto(Protocol):
    """Structural contract shared by executor implementations (local, spark, ...)."""

    def load_data(self) -> Any:
        """Load the executor's input data in whatever shape it uses."""
        ...

    def execute(self) -> SummaryModel:
        """Run the full evaluation pipeline and return its summary."""
        ...

    def evaluate(self):
        """Evaluate the loaded data (may be a no-op for some executors)."""
        ...

    def summarize(self, summary: SummaryModel) -> SummaryModel:
        """Fill derived statistics into *summary* and return the result."""
        ...
2121

2222

dingo/exec/local.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def load_data(self) -> Generator[MetaData, None, None]:
4444
dataset: Dataset = dataset_cls(source=datasource)
4545
return dataset.get_data()
4646

47-
def execute(self) -> List[SummaryModel]:
47+
def execute(self) -> SummaryModel:
4848
log.setLevel(self.input_args.log_level)
4949
create_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
5050
Model.apply_config(self.input_args.custom_config, self.input_args.eval_group)
@@ -72,7 +72,7 @@ def execute(self) -> List[SummaryModel]:
7272
self.summary = self.summarize(self.summary)
7373
self.write_summary(self.summary.output_path, self.input_args, self.summary)
7474

75-
return [self.summary]
75+
return self.summary
7676

7777
def evaluate(self):
7878
"""
@@ -231,7 +231,6 @@ def evaluate_prompt(self, group: List[BasePrompt], d: MetaData) -> ResultInfo:
231231

232232
def summarize(self, summary: SummaryModel) -> SummaryModel:
233233
new_summary = copy.deepcopy(summary)
234-
new_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
235234
if new_summary.total == 0:
236235
return new_summary
237236
new_summary.score = round(new_summary.num_good / new_summary.total * 100, 2)
@@ -241,6 +240,8 @@ def summarize(self, summary: SummaryModel) -> SummaryModel:
241240
new_summary.name_ratio[n] = round(new_summary.name_ratio[n] / new_summary.total, 6)
242241
new_summary.type_ratio = dict(sorted(new_summary.type_ratio.items()))
243242
new_summary.name_ratio = dict(sorted(new_summary.name_ratio.items()))
243+
244+
new_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
244245
return new_summary
245246

246247
def write_single_data(self, path: str, input_args: InputArgs, result_info: ResultInfo):

dingo/exec/spark.py

Lines changed: 59 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import os
23
import time
34
import uuid
@@ -53,7 +54,7 @@ def __getstate__(self):
5354
def __setstate__(self, state):
5455
self.__dict__.update(state)
5556

56-
def _initialize_spark(self):
57+
def initialize_spark(self):
5758
"""Initialize Spark session if not already provided."""
5859
if self.spark_session is not None:
5960
return self.spark_session, self.spark_session.sparkContext
@@ -63,11 +64,18 @@ def _initialize_spark(self):
6364
else:
6465
raise ValueError('Both spark_session and spark_conf are None. Please provide one.')
6566

67+
def cleanup(self, spark):
    """Clean up Spark resources (context first, then the session).

    Args:
        spark: the SparkSession to shut down; ``None`` is tolerated.

    The context check is nested under ``if spark`` — the previous,
    un-nested version raised AttributeError when ``spark`` was None.
    The context is stopped before ``spark.stop()`` because accessing
    ``spark.sparkContext`` on an already-stopped session is unreliable.
    """
    if spark:
        if spark.sparkContext:
            spark.sparkContext.stop()
        spark.stop()
73+
6674
def load_data(self) -> RDD:
    """Return the RDD held by this executor."""
    rdd = self.spark_rdd
    return rdd
6977

70-
def execute(self) -> List[SummaryModel]:
78+
def execute(self) -> SummaryModel:
7179
"""Main execution method for Spark evaluation."""
7280
create_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
7381

@@ -80,7 +88,7 @@ def execute(self) -> List[SummaryModel]:
8088
self.llm = Model.get_llm(llm_name)
8189

8290
print("============= Init PySpark =============")
83-
spark, sc = self._initialize_spark()
91+
spark, sc = self.initialize_spark()
8492
self._sc = sc
8593
print("============== Init Done ===============")
8694

@@ -98,7 +106,7 @@ def execute(self) -> List[SummaryModel]:
98106

99107
# Evaluate data
100108
data_info_list = data_rdd.map(
101-
lambda x: self._evaluate_item(x, broadcast_group, broadcast_llm)
109+
lambda x: self.evaluate_item(x, broadcast_group, broadcast_llm)
102110
).persist() # Cache the evaluated data for multiple uses
103111

104112
# Filter and count bad/good items
@@ -119,26 +127,24 @@ def execute(self) -> List[SummaryModel]:
119127
score=round((total - num_bad) / total * 100, 2) if total > 0 else 0,
120128
num_good=total - num_bad,
121129
num_bad=num_bad,
122-
total=total,
123-
type_ratio={},
124-
name_ratio={}
130+
total=total
125131
)
126132
# Generate detailed summary
127-
self._summarize_results()
128-
129-
self.summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
130-
131-
return [self.summary]
133+
self.summary = self.summarize(self.summary)
134+
return self.summary
132135

133136
except Exception as e:
134137
raise e
135138
finally:
136139
if not self.input_args.save_data:
137-
self._cleanup(spark)
140+
self.cleanup(spark)
138141
else:
139142
self.spark_session = spark
140143

141-
def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[str, Any]:
144+
def evaluate(self):
    """Intentionally a no-op: evaluation happens inside execute() here."""
    return None
146+
147+
def evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[str, Any]:
142148
"""Evaluate a single data item using broadcast variables."""
143149
data: MetaData = data_rdd_item
144150
result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
@@ -158,9 +164,9 @@ def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[
158164

159165
for group_type, group_items in group.items():
160166
if group_type == 'rule':
161-
r_i = self._evaluate_rule(group_items, data)
167+
r_i = self.evaluate_rule(group_items, data)
162168
elif group_type == 'prompt':
163-
r_i = self._evaluate_prompt(group_items, data, llm)
169+
r_i = self.evaluate_prompt(group_items, data, llm)
164170
else:
165171
raise RuntimeError(f'Unsupported group type: {group_type}')
166172

@@ -186,7 +192,7 @@ def _evaluate_item(self, data_rdd_item, broadcast_group, broadcast_llm) -> Dict[
186192

187193
return result_info.to_dict()
188194

189-
def _evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
195+
def evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
190196
"""Evaluate data against a group of rules."""
191197
result_info = ResultInfo(data_id=data.data_id, prompt=data.prompt, content=data.content)
192198

@@ -218,7 +224,7 @@ def _evaluate_rule(self, group: List[BaseRule], data: MetaData) -> ResultInfo:
218224

219225
return result_info
220226

221-
def _evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM) -> ResultInfo:
227+
def evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM) -> ResultInfo:
222228
"""Evaluate data against a group of prompts using LLM."""
223229
if llm is None:
224230
raise ValueError("LLM is required for prompt evaluation")
@@ -254,37 +260,42 @@ def _evaluate_prompt(self, group: List[BasePrompt], data: MetaData, llm: BaseLLM
254260

255261
return result_info
256262

257-
def _summarize_results(self):
263+
def summarize(self, summary: SummaryModel) -> SummaryModel:
    """Generate type/name ratio statistics for the given summary.

    Works on a deep copy, so the *summary* argument is never mutated.
    Ratios are computed from ``self.bad_info_list`` and, when
    ``input_args.save_correct`` is set, merged with ratios from
    ``self.good_info_list`` as well.

    Args:
        summary: the SummaryModel accumulated during execute().

    Returns:
        A new SummaryModel with type_ratio, name_ratio and finish_time
        filled in (early-returned unchanged copy when there is nothing
        to summarize).
    """
    def collect_ratio(data_info_list, key_name: str, total_count: int):
        # Count every entry under `key_name` across the RDD, then
        # normalize by `total_count`, rounded to 6 decimal places.
        data_info_counts = (
            data_info_list
            .flatMap(lambda x: [(t, 1) for t in x[key_name]])
            .reduceByKey(lambda a, b: a + b)
            .collectAsMap()
        )
        return {
            k: round(v / total_count, 6)
            for k, v in data_info_counts.items()
        }

    # BUG FIX: copy the `summary` argument — the ExecProto contract and the
    # call site `self.summary = self.summarize(self.summary)` both pass the
    # summary in; the old code silently ignored it and re-read self.summary.
    new_summary = copy.deepcopy(summary)
    if not self.bad_info_list and not self.good_info_list:
        return new_summary
    if not self.bad_info_list and self.good_info_list:
        if not self.input_args.save_correct:
            return new_summary

    new_summary.type_ratio = collect_ratio(self.bad_info_list, 'type_list', new_summary.total)
    new_summary.name_ratio = collect_ratio(self.bad_info_list, 'name_list', new_summary.total)

    if self.input_args.save_correct:
        # Merge in ratios computed from the correct (good) items as well.
        type_ratio_correct = collect_ratio(self.good_info_list, 'type_list', new_summary.total)
        name_ratio_correct = collect_ratio(self.good_info_list, 'name_list', new_summary.total)
        new_summary.type_ratio.update(type_ratio_correct)
        new_summary.name_ratio.update(name_ratio_correct)

    # Keep the ratio dicts deterministically ordered by key.
    new_summary.type_ratio = dict(sorted(new_summary.type_ratio.items()))
    new_summary.name_ratio = dict(sorted(new_summary.name_ratio.items()))

    new_summary.finish_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    return new_summary
288299

289300
def get_summary(self):
    """Return the summary produced by the most recent run."""
    result = self.summary
    return result
@@ -314,16 +325,3 @@ def get_good_info_list(self):
314325
}
315326
})
316327
return self.good_info_list
317-
318-
def save_data(self, start_time):
319-
"""Save output data to specified path."""
320-
output_path = os.path.join(self.input_args.output_path, start_time)
321-
model_path = os.path.join(output_path, self.input_args.eval_group)
322-
os.makedirs(model_path, exist_ok=True)
323-
324-
def _cleanup(self, spark):
325-
"""Clean up Spark resources."""
326-
if spark:
327-
spark.stop()
328-
if spark.sparkContext:
329-
spark.sparkContext.stop()

0 commit comments

Comments (0)