Skip to content

Commit 6e884d7

Browse files
authored
fix dataset bugs (#105)
* fix problems found * fix dataset bugs
1 parent 735f8bf commit 6e884d7

File tree

11 files changed

+417
-55
lines changed

11 files changed

+417
-55
lines changed

python/ppc_model/secure_model_base/secure_model_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(self,
119119

120120
model_predict_algorithm_str = common_func.get_config_value(
121121
"model_predict_algorithm", None, args, False)
122+
self.model_predict_algorithm = None
122123
if model_predict_algorithm_str is not None:
123124
self.model_predict_algorithm = json.loads(
124125
model_predict_algorithm_str)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
# In[4]:
5+
6+
7+
import sys
8+
import numpy as np
9+
import pandas as pd
10+
from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder
11+
from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit
12+
from wedpr_ml_toolkit.context.dataset_context import DatasetContext
13+
from wedpr_ml_toolkit.transport.wedpr_remote_dataset_client import DatasetMeta
14+
15+
16+
# In[5]:
17+
18+
19+
# 读取配置文件
20+
wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')
21+
wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)
22+
23+
24+
# In[8]:
25+
26+
27+
# 注册 dataset,支持两种方式: pd.Dataframe, hdfs_path
28+
# 1. pd.Dataframe
29+
df = pd.DataFrame({
30+
'id': np.arange(0, 100), # id列,顺序整数
31+
'y': np.random.randint(0, 2, size=100),
32+
# x1到x10列,随机数
33+
**{f't{i}': np.random.rand(100) for i in range(1, 11)}
34+
})
35+
36+
dataset1_meta = DatasetMeta(file_path = "d-01", user_config = wedpr_config.user_config)
37+
38+
dataset1 = DatasetContext(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
39+
storage_workspace=wedpr_config.user_config.get_workspace_path(),
40+
dataset_meta = dataset1_meta)
41+
print(f"* dataset1: {dataset1.dataset_meta}")
42+
dataset1.save_values(df)
43+
print(f"* updated dataset1: {dataset1}")
44+
45+
(values, cols, shape) = dataset1.load_values()
46+
print(f"* load values result: {cols}, {shape}, {values}")
47+
48+
49+
# In[11]:
50+
51+
52+
# 2. hdfs_path
53+
dataset2 = DatasetContext(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
54+
dataset_client=wedpr_ml_toolkit.dataset_client,
55+
storage_workspace=wedpr_config.user_config.get_workspace_path(),
56+
dataset_id = "d-9866227816474629")
57+
print(f"* dataset2 meta: {dataset2}")
58+
59+
# load values
60+
(values, cols, shape) = dataset2.load_values(header=0)
61+
print(f"* dataset2 detail, cols: {cols}, shape: {shape}, values: {values}")
62+
63+
64+
# 支持更新dataset的values数据
65+
df2 = pd.DataFrame({
66+
'id': np.arange(0, 100), # id列,顺序整数
67+
**{f'w{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数
68+
})
69+
dataset2.save_values(values=df2)
70+
71+
print(f"*** updated dataset2 meta: {dataset2}")
72+
73+
(values, cols, shape) = dataset2.load_values(header=0)
74+
print(f"*** updated dataset2 detail, cols: {cols}, shape: {shape}, values: {values}")
75+
76+
77+
# In[ ]:
78+
79+
80+
81+
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
# In[1]:
5+
6+
7+
import sys
8+
import numpy as np
9+
import pandas as pd
10+
from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder
11+
from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit
12+
from wedpr_ml_toolkit.context.dataset_context import DatasetContext
13+
from wedpr_ml_toolkit.context.data_context import DataContext
14+
from wedpr_ml_toolkit.context.job_context import JobType
15+
from wedpr_ml_toolkit.context.model_setting import ModelSetting
16+
from wedpr_ml_toolkit.context.result.model_result_context import PredictResultContext
17+
from wedpr_ml_toolkit.context.result.model_result_context import TrainResultContext
18+
from wedpr_ml_toolkit.context.result.model_result_context import PreprocessingResultContext
19+
20+
21+
# In[2]:
22+
23+
24+
# 读取配置文件
25+
wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')
26+
wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)
27+
28+
29+
# In[3]:
30+
31+
32+
# dataset1
33+
dataset1 = DatasetContext(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
34+
dataset_client=wedpr_ml_toolkit.dataset_client,
35+
dataset_id = 'd-9743660607744005',
36+
is_label_holder=True)
37+
print(f"* load dataset1: {dataset1}")
38+
(values, cols, shapes) = dataset1.load_values()
39+
print(f"* dataset1 detail: {cols}, {shapes}")
40+
print(f"* dataset1 value: {values}")
41+
42+
43+
# In[4]:
44+
45+
46+
# dataset2
47+
dataset2 = DatasetContext(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
48+
dataset_client = wedpr_ml_toolkit.dataset_client,
49+
dataset_id = "d-9743674298214405")
50+
print(f"* dataset2: {dataset2}")
51+
52+
# 构建 dataset context
53+
dataset = DataContext(dataset1, dataset2)
54+
print(dataset.datasets)
55+
56+
# init the job context
57+
project_id = "9737304249804806"
58+
59+
# 构造xgb任务配置
60+
model_setting = ModelSetting()
61+
model_setting.use_psi = True
62+
xgb_job_context = wedpr_ml_toolkit.build_job_context(
63+
job_type = JobType.XGB_TRAINING,
64+
project_id = project_id,
65+
dataset = dataset,
66+
model_setting = model_setting,
67+
id_fields = "id")
68+
print(f"* build xgb job context: {xgb_job_context}")
69+
70+
71+
# In[5]:
72+
73+
74+
# 执行xgb任务
75+
xgb_job_id = xgb_job_context.submit()
76+
print(xgb_job_id)
77+
78+
79+
# In[7]:
80+
81+
82+
# 获取xgb任务结果
83+
print(xgb_job_id)
84+
#xgb_job_id = "9868279583877126"
85+
xgb_result_detail = xgb_job_context.fetch_job_result(xgb_job_id, True)
86+
# load the result context
87+
xgb_result_context = wedpr_ml_toolkit.build_result_context(job_context=xgb_job_context,
88+
job_result_detail=xgb_result_detail)
89+
print(f"* xgb job result ctx: {xgb_result_context}")
90+
91+
xgb_test_dataset = xgb_result_context.test_result_dataset
92+
print(f"* xgb_test_dataset: {xgb_test_dataset}, file_path: {xgb_test_dataset.dataset_meta.file_path}")
93+
94+
(data, cols, shapes) = xgb_test_dataset.load_values()
95+
print(f"* test dataset detail, columns: {cols}, shape: {shapes}, value: {data}")
96+
97+
98+
# In[8]:
99+
100+
101+
# evaluation result
102+
result_context: TrainResultContext = xgb_result_context
103+
evaluation_result_dataset = result_context.evaluation_dataset
104+
(eval_data, cols, shape) = evaluation_result_dataset.load_values(header=0)
105+
print(f"* evaluation detail, col: {cols}, shape: {shape}, eval_data: {eval_data}")
106+
107+
108+
# In[9]:
109+
110+
111+
# feature importance
112+
feature_importance_dataset = result_context.feature_importance_dataset
113+
(feature_importance_data, cols, shape) = feature_importance_dataset.load_values()
114+
115+
print(f"* feature_importance detail, col: {cols}, shape: {shape}, feature_importance_data: {feature_importance_data}")
116+
117+
118+
# In[10]:
119+
120+
121+
# 预处理结果
122+
preprocessing_dataset = result_context.preprocessing_dataset
123+
(preprocessing_data, cols, shape) = preprocessing_dataset.load_values()
124+
125+
print(f"* preprocessing detail, col: {cols}, shape: {shape}, preprocessing_data: {preprocessing_data}")
126+
127+
128+
# In[11]:
129+
130+
131+
# 建模结果
132+
model_result_dataset = result_context.model_result_dataset
133+
(model_result, cols, shape) = model_result_dataset.load_values()
134+
135+
print(f"* model_result detail, col: {cols}, shape: {shape}, model_result: {model_result}")
136+
137+
138+
# In[12]:
139+
140+
141+
# 明文处理预测结果
142+
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score
143+
import matplotlib.pyplot as plt
144+
145+
# 提取真实标签和预测概率
146+
y_true = data['class_label']
147+
y_pred_proba = data['class_pred']
148+
y_pred = np.where(y_pred_proba >= 0.5, 1, 0) # 二分类阈值设为0.5
149+
150+
# 计算评估指标
151+
accuracy = accuracy_score(y_true, y_pred)
152+
precision = precision_score(y_true, y_pred)
153+
recall = recall_score(y_true, y_pred)
154+
f1 = f1_score(y_true, y_pred)
155+
auc = roc_auc_score(y_true, y_pred_proba)
156+
157+
print(f"Accuracy: {accuracy:.2f}")
158+
print(f"Precision: {precision:.2f}")
159+
print(f"Recall: {recall:.2f}")
160+
print(f"F1 Score: {f1:.2f}")
161+
print(f"AUC: {auc:.2f}")
162+
163+
# ROC 曲线
164+
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
165+
plt.figure(figsize=(12, 5))
166+
167+
# ROC 曲线
168+
plt.subplot(1, 2, 1)
169+
plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
170+
plt.plot([0, 1], [0, 1], 'k--')
171+
plt.xlabel('False Positive Rate')
172+
plt.ylabel('True Positive Rate')
173+
plt.title('ROC Curve')
174+
plt.legend()
175+
176+
# 精确率-召回率曲线
177+
precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)
178+
plt.subplot(1, 2, 2)
179+
plt.plot(recall_vals, precision_vals)
180+
plt.xlabel('Recall')
181+
plt.ylabel('Precision')
182+
plt.title('Precision-Recall Curve')
183+
184+
plt.tight_layout()
185+
plt.show()
186+
187+
188+
# In[13]:
189+
190+
191+
# 构造xgb预测任务配置
192+
predict_setting = ModelSetting()
193+
predict_setting.use_psi = True
194+
#model_predict_algorithm = {}
195+
#model_predict_algorithm.update({"setting": xgb_result_context.job_result_detail.model})
196+
predict_xgb_job_context = wedpr_ml_toolkit.build_job_context(
197+
job_type=JobType.XGB_PREDICTING,
198+
project_id = project_id,
199+
dataset= dataset,
200+
model_setting= predict_setting,
201+
id_fields = "id",
202+
predict_algorithm = xgb_result_context.job_result_detail.model_predict_algorithm)
203+
print(f"* predict_xgb_job_context: {predict_xgb_job_context}")
204+
205+
206+
# In[14]:
207+
208+
209+
# 执行xgb预测任务
210+
# xgb_job_id = '9868428439267334' # 测试时跳过创建新任务过程
211+
xgb_predict_job_id = predict_xgb_job_context.submit()
212+
print(xgb_predict_job_id)
213+
214+
215+
# In[15]:
216+
217+
218+
# query the job detail
219+
print(f"* xgb_predict_job_id: {xgb_predict_job_id}")
220+
221+
predict_xgb_job_result = predict_xgb_job_context.fetch_job_result(xgb_predict_job_id, True)
222+
223+
# generate the result context
224+
result_context = wedpr_ml_toolkit.build_result_context(job_context=predict_xgb_job_context,
225+
job_result_detail=predict_xgb_job_result)
226+
227+
xgb_predict_result_context : PredictResultContext = result_context
228+
print(f"* result_context is {xgb_predict_result_context}")
229+
230+
231+
# In[16]:
232+
233+
234+
# 明文处理预测结果
235+
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score
236+
import matplotlib.pyplot as plt
237+
238+
239+
(data, cols, shapes) = xgb_predict_result_context.model_result_dataset.load_values(header = 0)
240+
241+
# 提取真实标签和预测概率
242+
y_true = data['class_label']
243+
y_pred_proba = data['class_pred']
244+
y_pred = np.where(y_pred_proba >= 0.5, 1, 0) # 二分类阈值设为0.5
245+
246+
# 计算评估指标
247+
accuracy = accuracy_score(y_true, y_pred)
248+
precision = precision_score(y_true, y_pred)
249+
recall = recall_score(y_true, y_pred)
250+
f1 = f1_score(y_true, y_pred)
251+
auc = roc_auc_score(y_true, y_pred_proba)
252+
253+
print(f"Accuracy: {accuracy:.2f}")
254+
print(f"Precision: {precision:.2f}")
255+
print(f"Recall: {recall:.2f}")
256+
print(f"F1 Score: {f1:.2f}")
257+
print(f"AUC: {auc:.2f}")
258+
259+
# ROC 曲线
260+
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
261+
plt.figure(figsize=(12, 5))
262+
263+
# ROC 曲线
264+
plt.subplot(1, 2, 1)
265+
plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
266+
plt.plot([0, 1], [0, 1], 'k--')
267+
plt.xlabel('False Positive Rate')
268+
plt.ylabel('True Positive Rate')
269+
plt.title('ROC Curve')
270+
plt.legend()
271+
272+
# 精确率-召回率曲线
273+
precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)
274+
plt.subplot(1, 2, 2)
275+
plt.plot(recall_vals, precision_vals)
276+
plt.xlabel('Recall')
277+
plt.ylabel('Precision')
278+
plt.title('Precision-Recall Curve')
279+
280+
plt.tight_layout()
281+
plt.show()
282+
283+
284+
# In[ ]:
285+
286+
287+
288+

python/wedpr_ml_toolkit/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def run(self):
2020
setup_args = dict(
2121
name='wedpr_ml_toolkit',
2222
packages=find_packages(),
23-
version="1.0.0.dev-20241126",
23+
version="1.0.0.dev-20241129",
2424
description="wedpr-ml-toolkit: The ML toolkit for WeDPR",
2525
long_description_content_type="text/markdown",
2626
author="WeDPR Development Team",

0 commit comments

Comments
 (0)