Skip to content

Commit 2c1aebb

Browse files
authored
fix handshake bug (#80)
* remove handle_local_psi * fix handshake bug
1 parent fa96b32 commit 2c1aebb

File tree

7 files changed

+66
-67
lines changed

7 files changed

+66
-67
lines changed

cpp/wedpr-computing/ppc-psi/src/cm2020-psi/Common.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ inline uint32_t dedupDataBatch(ppc::io::DataBatch::Ptr dataBatch)
7272
return 0;
7373
}
7474
auto& data = dataBatch->mutableData();
75-
tbb::parallel_sort(data->begin(), data->end());
76-
auto unique_end = std::unique(data->begin(), data->end());
75+
// Note: the header field (the first element) should not be sorted
76+
auto it = data->begin() + 1;
77+
if (it >= data->end())
78+
{
79+
return data->size();
80+
}
81+
tbb::parallel_sort(it, data->end());
82+
auto unique_end = std::unique(it, data->end());
7783
data->erase(unique_end, data->end());
7884
return data->size();
7985
}

python/ppc_model/common/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class Context(BaseContext):
88

99
def __init__(self, job_id: str, task_id: str, components: Initializer, role: TaskRole = None):
1010
super().__init__(job_id, components.config_data['JOB_TEMP_DIR'])
11+
self.my_agency_id = components.config_data['AGENCY_ID']
1112
self.task_id = task_id
1213
self.components = components
1314
self.role = role

python/ppc_model/interface/model_base.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,23 @@ class ModelBase(ABC):
1010
def __init__(self, ctx):
1111
self.ctx = ctx
1212
self.ctx.model_router = self.ctx.components.model_router
13-
if self.ctx.role == TaskRole.ACTIVE_PARTY:
14-
self.__active_handshake__()
15-
else:
16-
self.__passive_handshake__()
13+
self.__handshake__()
1714

18-
def __active_handshake__(self):
15+
def __handshake__(self):
1916
# handshake with all passive parties
20-
for i in range(1, len(self.ctx.participant_id_list)):
17+
for i in range(0, len(self.ctx.participant_id_list)):
2118
participant = self.ctx.participant_id_list[i]
19+
if participant == self.ctx.my_agency_id:
20+
continue
2221
self.ctx.components.logger().info(
23-
f"Active: send handshake to passive party: {participant}")
22+
f"Send handshake to party: {participant}")
2423
self.ctx.model_router.handshake(self.ctx.task_id, participant)
25-
# wait for handshake response from the passive parties
26-
self.ctx.components.logger().info(
27-
f"Active: wait for handshake from passive party: {participant}")
28-
self.ctx.model_router.wait_for_handshake(self.ctx.task_id)
2924

30-
def __passive_handshake__(self):
31-
self.ctx.components.logger().info(
32-
f"Passive: send handshake to active party: {self.ctx.participant_id_list[0]}")
33-
# send handshake to the active party
34-
self.ctx.model_router.handshake(
35-
self.ctx.task_id, self.ctx.participant_id_list[0])
36-
# wait for handshake for the active party
25+
# wait for handshake response from the passive parties
3726
self.ctx.components.logger().info(
38-
f"Passive: wait for Handshake from active party: {self.ctx.participant_id_list[0]}")
39-
self.ctx.model_router.wait_for_handshake(self.ctx.task_id)
27+
f"Wait for handshake from all parities")
28+
self.ctx.model_router.wait_for_handshake(
29+
self.ctx.task_id, self.ctx.participant_id_list, self.ctx.my_agency_id)
4030

4131
def fit(
4232
self,

python/ppc_model/network/wedpr_model_transport.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,37 @@ def handshake(self, task_id, participant):
9393
seq=0, payload=bytes(),
9494
timeout=self.transport.send_msg_timeout)
9595

96-
def wait_for_handshake(self, task_id):
97-
topic = ModelTransport.get_topic_without_agency(
98-
task_id, BaseMessage.Handshake.value)
99-
self.transport.transport.register_topic(topic)
100-
result = self.transport.pop_by_topic(topic=topic, task_id=task_id)
101-
102-
if result is None:
103-
raise Exception(f"wait_for_handshake failed!")
104-
self.logger.info(
105-
f"wait_for_handshake success, task: {task_id}, detail: {result}")
106-
with self._rw_lock.gen_wlock():
107-
from_inst = result.get_header().get_src_inst()
96+
def __all_connected__(self, task_id, participant_id_list, self_agency_id):
97+
with self._rw_lock.gen_rlock():
10898
if task_id not in self.router_info.keys():
109-
self.router_info.update({task_id: dict()})
110-
self.router_info.get(task_id).update(
111-
{from_inst: result.get_header().get_src_node().decode("utf-8")})
99+
return False
100+
for participant in participant_id_list:
101+
if participant == self_agency_id:
102+
continue
103+
if participant not in self.router_info.get(task_id).keys():
104+
return False
105+
self.logger.info(
106+
f"__all_connected__, task: {task_id}, participant_id_list: {participant_id_list}")
107+
return True
108+
109+
def wait_for_handshake(self, task_id, participant_id_list: list, self_agency_id):
110+
while not self.__all_connected__(task_id, participant_id_list, self_agency_id):
111+
time.sleep(0.04)
112+
topic = ModelTransport.get_topic_without_agency(
113+
task_id, BaseMessage.Handshake.value)
114+
self.transport.transport.register_topic(topic)
115+
result = self.transport.pop_by_topic(topic=topic, task_id=task_id)
116+
117+
if result is None:
118+
raise Exception(f"wait_for_handshake failed!")
119+
self.logger.info(
120+
f"wait_for_handshake success, task: {task_id}, detail: {result}")
121+
with self._rw_lock.gen_wlock():
122+
from_inst = result.get_header().get_src_inst()
123+
if task_id not in self.router_info.keys():
124+
self.router_info.update({task_id: dict()})
125+
self.router_info.get(task_id).update(
126+
{from_inst: result.get_header().get_src_node().decode("utf-8")})
112127

113128
def on_task_finish(self, task_id):
114129
topic = ModelTransport.get_topic_without_agency(

python/ppc_model/preprocessing/local_processing/local_processing_party.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,10 @@ def processing(self):
3232
if need_psi and (not utils.file_exists(psi_result_path)):
3333
storage_client.download_file(
3434
self.ctx.remote_psi_result_path, psi_result_path)
35-
self.handle_local_psi_result(
36-
self.ctx.remote_psi_result_path, psi_result_path)
3735
log.info(
3836
f"prepare_xgb_after_psi, make_dataset_to_xgb_data_plus_psi_data, dataset_file_path={dataset_file_path}, "
39-
f"psi_result_path={dataset_file_path}, model_prepare_file={model_prepare_file}")
37+
f"psi_result_path={psi_result_path}, model_prepare_file={model_prepare_file}, "
38+
f"remote_psi_result_path: {self.ctx.remote_psi_result_path}")
4039
self.make_dataset_to_xgb_data()
4140
storage_client.upload_file(
4241
model_prepare_file, job_id + os.sep + BaseContext.MODEL_PREPARE_FILE)
@@ -48,25 +47,6 @@ def processing(self):
4847
log.info(
4948
f"call prepare_xgb_after_psi success, job_id={job_id}, timecost: {time.time() - start}")
5049

51-
def handle_local_psi_result(self, remote_psi_result_path, local_psi_result_path):
52-
try:
53-
log = self.ctx.components.logger()
54-
log.info(
55-
f"handle_local_psi_result: start handle_local_psi_result, psi_result_path={local_psi_result_path}")
56-
with open(local_psi_result_path, 'r+', encoding='utf-8') as psi_result_file:
57-
content = psi_result_file.read()
58-
psi_result_file.seek(0, 0)
59-
psi_result_file.write('id\n' + content)
60-
log.info(
61-
f"handle_local_psi_result: call handle_local_psi_result success, psi_result_path={local_psi_result_path}")
62-
# upload to remote
63-
self.ctx.components.storage_client.upload_file(
64-
local_psi_result_path, remote_psi_result_path)
65-
except BaseException as e:
66-
log.exception(
67-
f"handle_local_psi_result: handle_local_psi_result, psi_result_path={local_psi_result_path}, error:{e}")
68-
raise e
69-
7050
def make_dataset_to_xgb_data(self):
7151
log = self.ctx.components.logger()
7252
dataset_file_path = self.ctx.dataset_file_path

python/ppc_model/secure_lr/vertical/active_party.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ def fit(
4141
self.log.info(
4242
f'task {self.ctx.task_id}: Starting the lr on the active party.')
4343
self._init_active_data()
44-
45-
max_iter = self._init_iter(self.dataset.train_X.shape[0],
44+
45+
max_iter = self._init_iter(self.dataset.train_X.shape[0],
4646
self.params.epochs, self.params.batch_size)
47+
self.log.info(f"task: {self.ctx.task_id}, max_iter: {max_iter}")
4748
for _ in range(max_iter):
4849
self._iter_id += 1
4950
start_time = time.time()
@@ -59,7 +60,8 @@ def fit(
5960
self._build_iter(feature_select, idx)
6061

6162
# 预测
62-
self._train_praba = self._predict_tree(self.dataset.train_X, LRMessage.PREDICT_LEAF_MASK.value)
63+
self._train_praba = self._predict_tree(
64+
self.dataset.train_X, LRMessage.PREDICT_LEAF_MASK.value)
6365
# print('train_praba', set(self._train_praba))
6466

6567
# 评估
@@ -69,10 +71,11 @@ def fit(
6971
self.log.info(
7072
f'task {self.ctx.task_id}: iter-{self._iter_id}, auc: {auc}.')
7173
self.log.info(f'task {self.ctx.task_id}: Ending iter-{self._iter_id}, '
72-
f'time_costs: {time.time() - start_time}s.')
74+
f'time_costs: {time.time() - start_time}s.')
7375

7476
# 预测验证集
75-
self._test_praba = self._predict_tree(self.dataset.test_X, LRMessage.TEST_LEAF_MASK.value)
77+
self._test_praba = self._predict_tree(
78+
self.dataset.test_X, LRMessage.TEST_LEAF_MASK.value)
7679
if not self.params.silent and self.dataset.test_y is not None:
7780
auc = Evaluation.fevaluation(
7881
self.dataset.test_y, self._test_praba)['auc']
@@ -89,7 +92,8 @@ def predict(self, dataset: SecureDataset = None) -> np.ndarray:
8992
if dataset is None:
9093
dataset = self.dataset
9194

92-
test_praba = self._predict_tree(dataset.test_X, LRMessage.VALID_LEAF_MASK.value)
95+
test_praba = self._predict_tree(
96+
dataset.test_X, LRMessage.VALID_LEAF_MASK.value)
9397
self._test_praba = test_praba
9498

9599
if dataset.test_y is not None:
@@ -139,8 +143,10 @@ def _build_iter(self, feature_select, idx):
139143
public_key_list, d_other_list, partner_index_list = self._receive_d_instance_list()
140144
deriv = self._calculate_deriv(x_, d, partner_index_list, d_other_list)
141145

142-
self._train_weights -= self.params.learning_rate * deriv.astype('float')
143-
self._train_weights[~np.isin(np.arange(len(self._train_weights)), feature_select)] = 0
146+
self._train_weights -= self.params.learning_rate * \
147+
deriv.astype('float')
148+
self._train_weights[~np.isin(
149+
np.arange(len(self._train_weights)), feature_select)] = 0
144150

145151
def _predict_tree(self, X, key_type):
146152
train_g = self._loss_func.dot_product(X, self._train_weights)

python/tools/install.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
apt-get install pkg-config python3-dev default-libmysqlclient-dev build-essential
2+
apt-get install graphviz

0 commit comments

Comments
 (0)