@@ -41,9 +41,10 @@ def fit(
4141 self .log .info (
4242 f'task { self .ctx .task_id } : Starting the lr on the active party.' )
4343 self ._init_active_data ()
44-
45- max_iter = self ._init_iter (self .dataset .train_X .shape [0 ],
44+
45+ max_iter = self ._init_iter (self .dataset .train_X .shape [0 ],
4646 self .params .epochs , self .params .batch_size )
47+ self .log .info (f"task: { self .ctx .task_id } , max_iter: { max_iter } " )
4748 for _ in range (max_iter ):
4849 self ._iter_id += 1
4950 start_time = time .time ()
@@ -59,7 +60,8 @@ def fit(
5960 self ._build_iter (feature_select , idx )
6061
6162 # 预测
62- self ._train_praba = self ._predict_tree (self .dataset .train_X , LRMessage .PREDICT_LEAF_MASK .value )
63+ self ._train_praba = self ._predict_tree (
64+ self .dataset .train_X , LRMessage .PREDICT_LEAF_MASK .value )
6365 # print('train_praba', set(self._train_praba))
6466
6567 # 评估
@@ -69,10 +71,11 @@ def fit(
6971 self .log .info (
7072 f'task { self .ctx .task_id } : iter-{ self ._iter_id } , auc: { auc } .' )
7173 self .log .info (f'task { self .ctx .task_id } : Ending iter-{ self ._iter_id } , '
72- f'time_costs: { time .time () - start_time } s.' )
74+ f'time_costs: { time .time () - start_time } s.' )
7375
7476 # 预测验证集
75- self ._test_praba = self ._predict_tree (self .dataset .test_X , LRMessage .TEST_LEAF_MASK .value )
77+ self ._test_praba = self ._predict_tree (
78+ self .dataset .test_X , LRMessage .TEST_LEAF_MASK .value )
7679 if not self .params .silent and self .dataset .test_y is not None :
7780 auc = Evaluation .fevaluation (
7881 self .dataset .test_y , self ._test_praba )['auc' ]
@@ -89,7 +92,8 @@ def predict(self, dataset: SecureDataset = None) -> np.ndarray:
8992 if dataset is None :
9093 dataset = self .dataset
9194
92- test_praba = self ._predict_tree (dataset .test_X , LRMessage .VALID_LEAF_MASK .value )
95+ test_praba = self ._predict_tree (
96+ dataset .test_X , LRMessage .VALID_LEAF_MASK .value )
9397 self ._test_praba = test_praba
9498
9599 if dataset .test_y is not None :
@@ -139,8 +143,10 @@ def _build_iter(self, feature_select, idx):
139143 public_key_list , d_other_list , partner_index_list = self ._receive_d_instance_list ()
140144 deriv = self ._calculate_deriv (x_ , d , partner_index_list , d_other_list )
141145
142- self ._train_weights -= self .params .learning_rate * deriv .astype ('float' )
143- self ._train_weights [~ np .isin (np .arange (len (self ._train_weights )), feature_select )] = 0
146+ self ._train_weights -= self .params .learning_rate * \
147+ deriv .astype ('float' )
148+ self ._train_weights [~ np .isin (
149+ np .arange (len (self ._train_weights )), feature_select )] = 0
144150
145151 def _predict_tree (self , X , key_type ):
146152 train_g = self ._loss_func .dot_product (X , self ._train_weights )
0 commit comments