
Commit 735f8bf

update the meta information when save_values (#104)

* refactor all paths of result files to result_context
* update the meta information when save_values
* complement the result handler

1 parent 118a036 · commit 735f8bf

File tree

15 files changed: +417 −320 lines

python/wedpr_ml_toolkit/setup.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ def run(self):
 setup_args = dict(
     name='wedpr_ml_toolkit',
     packages=find_packages(),
-    version="1.0.0.dev-20241125",
+    version="1.0.0.dev-20241126",
     description="wedpr-ml-toolkit: The ML toolkit for WeDPR",
     long_description_content_type="text/markdown",
     author="WeDPR Development Team",

python/wedpr_ml_toolkit/test/test_ml_toolkit.py

Lines changed: 55 additions & 17 deletions

@@ -2,13 +2,13 @@
 import unittest
 import numpy as np
 import pandas as pd
-from sklearn import metrics
 from wedpr_ml_toolkit.config.wedpr_ml_config import WeDPRMlConfigBuilder
 from wedpr_ml_toolkit.wedpr_ml_toolkit import WeDPRMlToolkit
 from wedpr_ml_toolkit.context.dataset_context import DatasetContext
 from wedpr_ml_toolkit.context.data_context import DataContext
 from wedpr_ml_toolkit.context.job_context import JobType
 from wedpr_ml_toolkit.context.model_setting import PreprocessingSetting
+from wedpr_ml_toolkit.context.model_setting import ModelSetting


 class WeDPRMlToolkitTestWrapper:
@@ -25,28 +25,19 @@ def test_submit_job(self):
             # columns x1 to x10: random values
             **{f'x{i}': np.random.rand(100) for i in range(1, 11)}
         })
+        # the dataset
         dataset1 = DatasetContext(storage_entrypoint=self.wedpr_ml_toolkit.get_storage_entry_point(),
                                   dataset_client=self.wedpr_ml_toolkit.get_dataset_client(),
                                   storage_workspace=self.wedpr_config.user_config.get_workspace_path(),
                                   dataset_id="d-9743660607744005",
                                   is_label_holder=True)
         dataset1.save_values(df, path='d-101')

-        # hdfs_path
+        # the dataset
         dataset2 = DatasetContext(storage_entrypoint=self.wedpr_ml_toolkit.get_storage_entry_point(),
                                   dataset_client=self.wedpr_ml_toolkit.get_dataset_client(),
                                   dataset_id="d-9743674298214405")
-
-        dataset2.storage_client = None
-        # dataset2.load_values()
-        if dataset2.storage_client is None:
-            # allow updating the values of a dataset
-            df2 = pd.DataFrame({
-                'id': np.arange(0, 100),  # id column: sequential integers
-                # columns z1 to z10: random values
-                **{f'z{i}': np.random.rand(100) for i in range(1, 11)}
-            })
-            dataset2.save_values(values=df2)
+        print(f"### dataset2 meta: {dataset2.dataset_meta}")
         if dataset1.storage_client is not None:
             # save values to dataset1
             dataset1.save_values(df)
@@ -70,20 +61,67 @@ def test_submit_job(self):
         psi_result = psi_job_context.fetch_job_result(psi_job_id, True)
         print(
             f"* fetch_job_result for psi job {psi_job_id} success, result: {psi_result}")
+        # build the psi result:
+        psi_result_ctx = self.wedpr_ml_toolkit.build_result_context(
+            psi_job_context, psi_result)
+        print(f"* psi_result_ctx: {psi_result_ctx}")
+        (psi_result_values, psi_result_columns,
+         psi_result_shape) = psi_result_ctx.result_dataset.load_values()
+        # obtain the intersection
+        print(
+            f"* psi result, psi_result_columns: {psi_result_columns}, "
+            f"psi_result_shape: {psi_result_shape}, psi_result_values: {psi_result_values}")
         # initialization
         print(f"* build pre-processing data-context")
         preprocessing_data = DataContext(dataset1, dataset2)
         preprocessing_job_context = self.wedpr_ml_toolkit.build_job_context(
             JobType.PREPROCESSING, project_id, preprocessing_data, PreprocessingSetting())
         # run the pre-processing job
         print(f"* submit pre-processing job")
-        fe_job_id = preprocessing_job_context.submit()
-        print(f"* submit pre-processing job success, job_id: {fe_job_id}")
-        fe_result = preprocessing_job_context.fetch_job_result(fe_job_id, True)
+        preprocessing_job_id = preprocessing_job_context.submit()
+        print(
+            f"* submit pre-processing job success, job_id: {preprocessing_job_id}")
+        preprocessing_result = preprocessing_job_context.fetch_job_result(
+            preprocessing_job_id, True)
         print(
-            f"* fetch pre-processing job result success, job_id: {fe_job_id}, result: {fe_result}")
+            f"* fetch pre-processing job result success, job_id: {preprocessing_job_id}, result: {preprocessing_result}")
         print(preprocessing_job_context.participant_id_list,
               preprocessing_job_context.result_receiver_id_list)
+        # build the context
+        preprocessing_result_ctx = self.wedpr_ml_toolkit.build_result_context(preprocessing_job_context,
+                                                                              preprocessing_result)
+        print(
+            f"* preprocessing_result_ctx: {preprocessing_result_ctx.preprocessing_dataset}")
+        preprocessing_values, columns, shape = preprocessing_result_ctx.preprocessing_dataset.load_values()
+        print(
+            f"* preprocessing_result_dataset, columns: {columns}, shape: {shape}")
+        # test xgb job
+        xgb_data = DataContext(dataset1, dataset2)
+        model_setting = ModelSetting()
+        model_setting.use_psi = True
+        xgb_job_context = self.wedpr_ml_toolkit.build_job_context(
+            job_type=JobType.XGB_TRAINING, project_id=project_id,
+            dataset=xgb_data,
+            model_setting=model_setting, id_fields="id")
+        print(f"* construct xgb job context: participant_id_list: {xgb_job_context.participant_id_list}, "
+              f"result_receiver_id_list: {xgb_job_context.result_receiver_id_list}")
+        xgb_job_id = xgb_job_context.submit()
+        print(f"* submit xgb job success, {xgb_job_id}")
+        xgb_job_result = xgb_job_context.fetch_job_result(xgb_job_id, True)
+        print(f"* xgb job result: {xgb_job_result}")
+        xgb_job_context = self.wedpr_ml_toolkit.build_result_context(
+            job_context=xgb_job_context, job_result_detail=xgb_job_result)
+        print(f"* xgb result context: {xgb_job_context}")
+        # load the feature_importance information
+        (feature_importance_value, feature_importance_cols, feature_importance_shape) = \
+            xgb_job_context.feature_importance_dataset.load_values()
+        print(f"* xgb feature importance information: {feature_importance_cols}, "
+              f"{feature_importance_shape}, {feature_importance_value}")
+        # load the evaluation information
+        (evaluation_value, evaluation_cols, evaluation_shape) = \
+            xgb_job_context.evaluation_dataset.load_values()
+        print(f"* xgb evaluation information: {evaluation_cols}, "
+              f"{evaluation_shape}, {evaluation_value}")

     def test_query_job(self, job_id: str, block_until_finish):
         job_result = self.wedpr_ml_toolkit.query_job_status(
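Taken together, the updated test walks the new result-context flow end to end: submit a job, fetch its result detail, wrap it with build_result_context, then read the result dataset. A condensed sketch of that flow (assuming `toolkit` is an already-configured WeDPRMlToolkit and `dataset1`/`dataset2` are DatasetContext instances as above; the project id is a placeholder):

    from wedpr_ml_toolkit.context.data_context import DataContext
    from wedpr_ml_toolkit.context.job_context import JobType
    from wedpr_ml_toolkit.context.model_setting import PreprocessingSetting

    data = DataContext(dataset1, dataset2)
    job_ctx = toolkit.build_job_context(
        JobType.PREPROCESSING, "p-0000000000000000", data, PreprocessingSetting())
    job_id = job_ctx.submit()
    # block until the job finishes, then fetch the result detail
    result = job_ctx.fetch_job_result(job_id, True)
    # new in this commit: wrap the detail in a typed result context
    result_ctx = toolkit.build_result_context(job_ctx, result)
    values, columns, shape = result_ctx.preprocessing_dataset.load_values()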

python/wedpr_ml_toolkit/wedpr_ml_toolkit/common/utils/constant.py

Lines changed: 10 additions & 4 deletions

@@ -7,11 +7,17 @@ class Constant:
     DEFAULT_SUBMIT_JOB_URI = f'{WEDPR_API_PREFIX}project/submitJob'
     DEFAULT_QUERY_JOB_STATUS_URL = f'{WEDPR_API_PREFIX}project/queryJobByCondition'
     DEFAULT_QUERY_JOB_DETAIL_URL = f'{WEDPR_API_PREFIX}scheduler/queryJobDetail'
+    # the dataset related url
     DEFAULT_QUERY_DATASET_URL = f'{WEDPR_API_PREFIX}dataset/queryDataset'
+    DEFAULT_UPDATED_DATASET_URL = f'{WEDPR_API_PREFIX}dataset/updateDatasetMeta'
     PSI_RESULT_FILE = "psi_result.csv"

     FEATURE_BIN_FILE = "feature_bin.json"
-    TEST_MODEL_OUTPUT_FILE = "test_output.csv"
-    TRAIN_MODEL_OUTPUT_FILE = "train_output.csv"
-
-    FE_RESULT_FILE = "fe_result.csv"
+    XGB_TREE_PREFIX = "xgb_tree"
+    MODEL_RESULT_FILE = XGB_TREE_PREFIX + '.json'
+    PREPROCESSING_RESULT_FILE = "preprocessing_result.csv"
+    EVALUATION_TABLE_FILE = "mpc_xgb_evaluation_table.csv"
+    FEATURE_IMPORTANCE_FILE = "xgb_result_feature_importance_table.csv"
+    FEATURE_SELECTION_FILE = "xgb_result_column_info_selected.csv"
+    MODEL_FILE = "model_enc.kpl"
+    WOE_IV_FILE = "woe_iv.csv"
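These constants name the per-job result files that the refactored result contexts resolve. As a rough illustration of the lookup, a result path could be built by joining a job's output directory with one of these names (the output directory value and the locate_result helper below are hypothetical, for illustration only):

    import os
    from wedpr_ml_toolkit.common.utils.constant import Constant

    def locate_result(output_dir: str, file_name: str) -> str:
        # hypothetical helper: join a job's output directory with one of
        # the well-known result file names defined on Constant
        return os.path.join(output_dir, file_name)

    # e.g. the evaluation table produced by an xgb training job
    evaluation_path = locate_result(
        "/user/wedpr/jobs/j-0000", Constant.EVALUATION_TABLE_FILE)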

python/wedpr_ml_toolkit/wedpr_ml_toolkit/config/wedpr_ml_config.py

Lines changed: 4 additions & 8 deletions

@@ -32,8 +32,11 @@ def __init__(self, polling_interval_s: int = 5, max_retries: int = 2, retry_dela


 class DatasetConfig(BaseObject):
-    def __init__(self, query_dataset_uri=Constant.DEFAULT_QUERY_DATASET_URL):
+    def __init__(self,
+                 query_dataset_uri=Constant.DEFAULT_QUERY_DATASET_URL,
+                 update_dataset_uri=Constant.DEFAULT_UPDATED_DATASET_URL):
         self.query_dataset_uri = query_dataset_uri
+        self.update_dataset_uri = update_dataset_uri


 class StorageConfig(BaseObject):
@@ -56,11 +59,6 @@ def __init__(self, timeout_seconds=3):
         self.timeout_seconds = timeout_seconds


-class AgencyConfig(BaseObject):
-    def __init__(self, agency_name=None):
-        self.agency_name = agency_name
-
-
 class WeDPRMlConfig:
     def __init__(self, config_dict):
         self.auth_config = AuthConfig()
@@ -73,8 +71,6 @@ def __init__(self, config_dict):
         self.user_config.set_params(**config_dict)
         self.http_config = HttpConfig()
         self.http_config.set_params(**config_dict)
-        self.agency_config = AgencyConfig()
-        self.agency_config.set_params(**config_dict)
         self.dataset_config = DatasetConfig()

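With AgencyConfig removed, DatasetConfig now carries the update endpoint by default, so a config built from a flat dict exposes both dataset URIs without extra wiring. A minimal sketch (set_params copies matching keys from config_dict onto each sub-config; the keys shown here are illustrative placeholders, not the authoritative names):

    # sketch only: the exact config keys depend on each sub-config's fields
    config_dict = {"access_key_id": "...", "user": "alice"}
    config = WeDPRMlConfig(config_dict)
    print(config.dataset_config.query_dataset_uri)   # ...dataset/queryDataset
    print(config.dataset_config.update_dataset_uri)  # ...dataset/updateDatasetMeta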
python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/data_context.py

Lines changed: 2 additions & 23 deletions

@@ -9,23 +9,6 @@ class DataContext:
     def __init__(self, *datasets):
         self.datasets = list(datasets)

-        self._check_datasets()
-
-    def _save_dataset(self, dataset: DatasetContext):
-        file_path = dataset.dataset_meta.file_path
-        if file_path is None:
-            dataset.dataset_id = utils.make_id(
-                utils.IdPrefixEnum.DATASET.value)
-            file_path = os.path.join(
-                dataset.storage_workspace, dataset.dataset_id)
-        if dataset.storage_client is not None:
-            dataset.storage_client.upload(
-                dataset.values, file_path)
-
-    def _check_datasets(self):
-        for dataset in self.datasets:
-            self._save_dataset(dataset)
-
     def to_psi_format(self, merge_filed, result_receiver_id_list):
         dataset_psi = []
         for dataset in self.datasets:
@@ -40,14 +23,10 @@ def to_psi_format(self, merge_filed, result_receiver_id_list):

     def __generate_dataset_info__(self, id_field: str, receive_result: bool, label_provider: bool, dataset: DatasetContext):
         return {"idFields": [id_field],
-                "dataset": {"owner": dataset.dataset_meta.ownerUserName,
-                            "ownerAgency": dataset.dataset_meta.ownerAgencyName,
-                            "path": dataset.dataset_meta.file_path,
-                            "storageTypeStr": "HDFS",
+                "dataset": {"ownerAgency": dataset.dataset_meta.ownerAgencyName,
                             "datasetID": dataset.dataset_id},
                 "receiveResult": receive_result,
-                "labelProvider": label_provider
-                }
+                "labelProvider": label_provider}

     def to_model_formort(self, merge_filed, result_receiver_id_list):
         dataset_model = []
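After this change each dataset entry sent to the job service identifies the dataset only by its owner agency and dataset ID; the previous owner, path, and storageTypeStr fields are left for the server to resolve from the ID. For a single dataset the generated entry now looks roughly like this (values illustrative):

    {
        "idFields": ["id"],
        "dataset": {
            "ownerAgency": "agencyA",           # dataset.dataset_meta.ownerAgencyName
            "datasetID": "d-9743660607744005"   # dataset.dataset_id
        },
        "receiveResult": True,
        "labelProvider": True
    }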

python/wedpr_ml_toolkit/wedpr_ml_toolkit/context/dataset_context.py

Lines changed: 25 additions & 1 deletion

@@ -3,6 +3,7 @@
 from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint
 from wedpr_ml_toolkit.transport.wedpr_remote_dataset_client import WeDPRDatasetClient
 from wedpr_ml_toolkit.transport.wedpr_remote_dataset_client import DatasetMeta
+import io


 class DatasetContext:
@@ -27,6 +28,10 @@ def __init__(self,
         # the storage workspace
         self.storage_workspace = storage_workspace

+    def __repr__(self):
+        return f"dataset_id: {self.dataset_id}, " \
+               f"dataset_meta: {self.dataset_meta}"
+
     def load_values(self, header=None):
         # load the dataset from HDFS
         if self.storage_client is not None:
@@ -37,6 +42,25 @@ def load_values(self, header=None):
             return values, values.columns, values.shape

     def save_values(self, values: pd.DataFrame = None, path=None):
+        # no values to save
+        if values is None:
+            return
+        csv_buffer = io.StringIO()
+        values.to_csv(csv_buffer, index=False)
+        value_bytes = csv_buffer.getvalue()
+        # update the meta firstly
+        if path is None and self.dataset_meta is not None and self.dataset_meta.datasetId is not None:
+            columns = values.columns.to_list()
+            dataset_meta = DatasetMeta(dataset_id=self.dataset_meta.datasetId,
+                                       dataset_fields=','.join(columns),
+                                       dataset_size=len(value_bytes),
+                                       dataset_record_count=len(values),
+                                       dataset_column_count=len(columns))
+            self.dataset_client.update_dataset(dataset_meta)
+            self.dataset_meta.datasetFields = ','.join(columns)
+            self.dataset_meta.dataset_record_count = len(values)
+            self.dataset_meta.columnCount = len(columns)
+        # update the content
         target_path = self.dataset_meta.file_path
         # save the data to the HDFS directory
         if path is not None:
@@ -47,7 +71,7 @@ def save_values(self, values: pd.DataFrame = None, path=None):
         target_path = os.path.join(
             self.storage_workspace, target_path)
         if self.storage_client is not None:
-            self.storage_client.upload(values, target_path)
+            self.storage_client.upload_bytes(value_bytes, target_path)

     def update_path(self, path: str = None):
         # store the dataset at the same HDFS path, replacing the old dataset
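In short, save_values now serializes the DataFrame to CSV once, pushes the refreshed metadata (field list, byte size, record and column counts) to the dataset service via update_dataset, and only then uploads the bytes. A small usage sketch (assuming `dataset` is a DatasetContext wired with a dataset_client and storage_client, as in the test above):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"id": np.arange(100), "x1": np.random.rand(100)})
    dataset.save_values(df)
    # with no explicit path, the remote meta is updated first:
    #   datasetFields -> "id,x1", record count -> 100, column count -> 2,
    #   dataset size  -> number of CSV bytes
    # then the CSV bytes are uploaded to dataset_meta.file_path
    print(dataset.dataset_meta)   # __repr__ added in this commit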
