22import unittest
33import numpy as np
44import pandas as pd
5- from sklearn import metrics
65from wedpr_ml_toolkit .config .wedpr_ml_config import WeDPRMlConfigBuilder
76from wedpr_ml_toolkit .wedpr_ml_toolkit import WeDPRMlToolkit
87from wedpr_ml_toolkit .context .dataset_context import DatasetContext
98from wedpr_ml_toolkit .context .data_context import DataContext
109from wedpr_ml_toolkit .context .job_context import JobType
1110from wedpr_ml_toolkit .context .model_setting import PreprocessingSetting
11+ from wedpr_ml_toolkit .context .model_setting import ModelSetting
1212
1313
1414class WeDPRMlToolkitTestWrapper :
@@ -25,28 +25,19 @@ def test_submit_job(self):
2525 # columns x1 through x10: random values
2626 ** {f'x{ i } ' : np .random .rand (100 ) for i in range (1 , 11 )}
2727 })
28+ # dataset1: the label holder's dataset
2829 dataset1 = DatasetContext (storage_entrypoint = self .wedpr_ml_toolkit .get_storage_entry_point (),
2930 dataset_client = self .wedpr_ml_toolkit .get_dataset_client (),
3031 storage_workspace = self .wedpr_config .user_config .get_workspace_path (),
3132 dataset_id = "d-9743660607744005" ,
3233 is_label_holder = True )
3334 dataset1 .save_values (df , path = 'd-101' )
3435
35- # hdfs_path
36+ # dataset2: referenced by dataset_id only
3637 dataset2 = DatasetContext (storage_entrypoint = self .wedpr_ml_toolkit .get_storage_entry_point (),
3738 dataset_client = self .wedpr_ml_toolkit .get_dataset_client (),
3839 dataset_id = "d-9743674298214405" )
39-
40- dataset2 .storage_client = None
41- # dataset2.load_values()
42- if dataset2 .storage_client is None :
43- # 支持更新dataset的values数据
44- df2 = pd .DataFrame ({
45- 'id' : np .arange (0 , 100 ), # id列,顺序整数
46- # x1到x10列,随机数
47- ** {f'z{ i } ' : np .random .rand (100 ) for i in range (1 , 11 )}
48- })
49- dataset2 .save_values (values = df2 )
40+ print (f"### dataset2 meta: { dataset2 .dataset_meta } " )
5041 if dataset1 .storage_client is not None :
5142 # save values to dataset1
5243 dataset1 .save_values (df )
@@ -70,20 +61,67 @@ def test_submit_job(self):
7061 psi_result = psi_job_context .fetch_job_result (psi_job_id , True )
7162 print (
7263 f"* fetch_job_result for psi job { psi_job_id } success, result: { psi_result } " )
64+ # build the result context for the psi job
65+ psi_result_ctx = self .wedpr_ml_toolkit .build_result_context (
66+ psi_job_context , psi_result )
67+ print (f"* psi_result_ctx: { psi_result_ctx } " )
68+ (psi_result_values , psi_result_columns ,
69+ psi_result_shape ) = psi_result_ctx .result_dataset .load_values ()
70+ # print the loaded psi result (columns, shape, values)
71+ print (
72+ f"* psi result, psi_result_columns: { psi_result_columns } , "
73+ f"psi_result_shape: { psi_result_shape } , psi_result_values: { psi_result_values } " )
7374 # initialize the pre-processing data context
7475 print (f"* build pre-processing data-context" )
7576 preprocessing_data = DataContext (dataset1 , dataset2 )
7677 preprocessing_job_context = self .wedpr_ml_toolkit .build_job_context (
7778 JobType .PREPROCESSING , project_id , preprocessing_data , PreprocessingSetting ())
7879 # 执行预处理任务
7980 print (f"* submit pre-processing job" )
80- fe_job_id = preprocessing_job_context .submit ()
81- print (f"* submit pre-processing job success, job_id: { fe_job_id } " )
82- fe_result = preprocessing_job_context .fetch_job_result (fe_job_id , True )
81+ preprocessing_job_id = preprocessing_job_context .submit ()
82+ print (
83+ f"* submit pre-processing job success, job_id: { preprocessing_job_id } " )
84+ preprocessing_result = preprocessing_job_context .fetch_job_result (
85+ preprocessing_job_id , True )
8386 print (
84- f"* fetch pre-processing job result success, job_id: { fe_job_id } , result: { fe_result } " )
87+ f"* fetch pre-processing job result success, job_id: { preprocessing_job_id } , result: { preprocessing_result } " )
8588 print (preprocessing_job_context .participant_id_list ,
8689 preprocessing_job_context .result_receiver_id_list )
90+ # build the result context for the pre-processing job
91+ preprocessing_result_ctx = self .wedpr_ml_toolkit .build_result_context (preprocessing_job_context ,
92+ preprocessing_result )
93+ print (
94+ f"* preprocessing_result_ctx: { preprocessing_result_ctx .preprocessing_dataset } " )
95+ preprocessing_values , columns , shape = preprocessing_result_ctx .preprocessing_dataset .load_values ()
96+ print (
97+ f"* preprocessing_result_dataset, columns: { columns } , shape: { shape } " )
98+ # test xgb job
99+ xgb_data = DataContext (dataset1 , dataset2 )
100+ model_setting = ModelSetting ()
101+ model_setting .use_psi = True
102+ xgb_job_context = self .wedpr_ml_toolkit .build_job_context (
103+ job_type = JobType .XGB_TRAINING , project_id = project_id ,
104+ dataset = xgb_data ,
105+ model_setting = model_setting , id_fields = "id" )
106+ print (f"* construct xgb job context: participant_id_list: { xgb_job_context .participant_id_list } , "
107+ f"result_receiver_id_list: { xgb_job_context .result_receiver_id_list } " )
108+ xgb_job_id = xgb_job_context .submit ()
109+ print (f"* submit xgb job success, { xgb_job_id } " )
110+ xgb_job_result = xgb_job_context .fetch_job_result (xgb_job_id , True )
111+ print (f"* xgb job result: { xgb_job_result } " )
112+ xgb_job_context = self .wedpr_ml_toolkit .build_result_context (
113+ job_context = xgb_job_context , job_result_detail = xgb_job_result )
114+ print (f"* xgb job result: { xgb_job_context } " )
115+ # load the feature_importance information
116+ (feature_importance_value , feature_importance_cols , feature_importance_shape ) = \
117+ xgb_job_context .feature_importance_dataset .load_values ()
118+ print (f"* xgb feature importance information: { feature_importance_cols } , "
119+ f"{ feature_importance_shape } , { feature_importance_value } " )
120+ # load the evaluation information
121+ (evaluation_value , evaluation_cols , evaluation_shape ) = \
122+ xgb_job_context .evaluation_dataset .load_values ()
123+ print (f"* xgb evaluation information: { evaluation_cols } , "
124+ f"{ evaluation_shape } , { evaluation_value } " )
87125
88126 def test_query_job (self , job_id : str , block_until_finish ):
89127 job_result = self .wedpr_ml_toolkit .query_job_status (
0 commit comments