Commit 3701d11

Merge pull request #59 from Xtra-Computing/layerwise: Layerwise

2 parents 6a9a873 + 82da2f3

File tree

11 files changed: +482 −69 lines

include/FedTree/Tree/function_builder.h

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ class FunctionBuilder {
 public:
     virtual vector<Tree> build_approximate(const SyncArray<GHPair> &gradients, bool update_y_predict = true) = 0;

+    virtual vector<Tree> build_a_subtree_approximate(const SyncArray<GHPair> &gradients, int n_layer) = 0;
+
     virtual Tree get_tree() = 0;

     virtual void set_tree(Tree tree) = 0;

include/FedTree/Tree/gbdt.h

Lines changed: 4 additions & 0 deletions
@@ -20,12 +20,16 @@ class GBDT {

     void train(GBDTParam &param, DataSet &dataset);

+    void train_a_subtree(GBDTParam &param, DataSet &dataset, int n_layer, int *id_list, int *nins_list, float *gradient_g_list, float *gradient_h_list, int *n_node, int *nodeid_list, float *input_gradient_g, float *input_gradient_h);
+
     vector<float_type> predict(const GBDTParam &model_param, const DataSet &dataSet);

     vector<float_type> predict(const GBDTParam &model_param, const vector<DataSet> &dataSet);

     void predict_raw(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict);

+    void predict_leaf(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict, int *ins2leaf);
+
     void predict_raw_vertical(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict, std::map<int, vector<int>> &batch_idxs);

     void predict_raw_vertical(const GBDTParam &model_param, const vector<DataSet> &dataSet, SyncArray<float_type> &y_predict);
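The out-parameters of train_a_subtree suggest the layerwise contract: the library grows only the top n_layer levels, then hands back the frontier (n_node node ids in nodeid_list), the instance ids grouped by frontier node (id_list, with group sizes in nins_list), and per-instance gradients. A minimal sketch of how a caller might turn that partition into leaf weights with the standard second-order GBDT formula w = -G / (H + lambda) — illustrative only, not FedTree code:

    import numpy as np

    def leaf_weights(id_list, nins_list, g, h, lam=1.0):
        # id_list: instance ids grouped by frontier node; nins_list: group sizes
        # g, h: per-instance gradient / hessian arrays (np.ndarray)
        weights, offset = [], 0
        for n in nins_list:
            ids = id_list[offset:offset + n]
            G, H = g[ids].sum(), h[ids].sum()
            weights.append(-G / (H + lam))  # second-order leaf weight
            offset += n
        return np.asarray(weights)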

include/FedTree/Tree/tree_builder.h

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ class TreeBuilder : public FunctionBuilder{

     vector<Tree> build_approximate(const SyncArray<GHPair> &gradients, bool update_y_predict = true) override;

+    vector<Tree> build_a_subtree_approximate(const SyncArray<GHPair> &gradients, int n_layer) override;
+
     void build_tree_by_predefined_structure(const SyncArray<GHPair> &gradients, vector<Tree> &trees);

     void build_init(const GHPair sum_gh, int k) override;

include/FedTree/booster.h

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ class Booster {

     void boost(vector<vector<Tree>> &boosted_model);

+    void boost_a_subtree(vector<vector<Tree>> &trees, int n_layer, int *id_list, int *nins_list, float *gradient_g_list,
+                         float *gradient_h_list, int *n_node, int *nodeid_list, float *input_gradient_g, float *input_gradient_h);
+
     void boost_without_prediction(vector<vector<Tree>> &boosted_model);

     GBDTParam param;

python/fedtree/fedtree.py

Lines changed: 116 additions & 1 deletion
@@ -96,6 +96,7 @@ def fit(self, X, y, groups=None):
            self.model = None
        sparse = sp.issparse(X)
        if sparse is False:
+           # potential bug: csr_matrix ignores all zero values in X
            X = sp.csr_matrix(X)
        X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
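Why the "potential bug" comment above matters (an illustrative sketch, not part of the commit): constructing a csr_matrix from a dense array stores only the nonzero entries, so explicit zeros become "missing" to the trees and take the default branch during traversal.

    import numpy as np
    import scipy.sparse as sp

    X = np.array([[0.0, 1.5], [2.0, 0.0]])
    Xs = sp.csr_matrix(X)
    print(Xs.nnz)  # 2 -- the zero entries are not stored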

@@ -182,6 +183,52 @@ def predict(self, X, groups=None):
        predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]
        self.predict_label = np.asarray(predict_label)
        return self.predict_label
+
+   def predict_leaf(self, X, groups=None):
+       if self.model is None:
+           print("Please train the model first or load model from file!")
+           raise ValueError
+       sparse = sp.isspmatrix(X)
+       if sparse is False:
+           X = sp.csr_matrix(X)
+       X.data = np.asarray(X.data, dtype=np.float32, order='C')
+       X.sort_indices()
+       data = X.data.ctypes.data_as(POINTER(c_float))
+       indices = X.indices.ctypes.data_as(POINTER(c_int32))
+       indptr = X.indptr.ctypes.data_as(POINTER(c_int32))
+       if self.objective != 'multi:softprob':
+           self.predict_label_ptr = (c_float * X.shape[0])()
+       else:
+           temp_size = X.shape[0] * self.num_class
+           self.predict_label_ptr = (c_float * temp_size)()
+       if self.group_label is not None:
+           group_label = (c_float * len(self.group_label))()
+           group_label[:] = self.group_label
+       else:
+           group_label = None
+       in_groups, num_groups = self._construct_groups(groups)
+       ins2leaf_c = (c_int32 * (X.shape[0] * self.n_trees))()
+       fedtree.predict_leaf(
+           X.shape[0],
+           data,
+           indptr,
+           indices,
+           self.predict_label_ptr,
+           byref(self.model),
+           self.n_trees,
+           self.tree_per_iter,
+           self.objective.encode('utf-8'),
+           self.num_class,
+           c_float(self.learning_rate),
+           group_label,
+           in_groups,
+           ins2leaf_c,
+           num_groups, self.verbose, self.bagging,
+       )
+       self.ins2leaf = np.array([ins2leaf_c[i] for i in range(X.shape[0] * self.n_trees)])
+       # predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]
+       # self.predict_label = np.asarray(predict_label)
+       return self.ins2leaf
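Hypothetical usage of the new predict_leaf (an illustrative sketch; `model` is a trained estimator from this module and `X_test` is placeholder data):

    leaf_ids = model.predict_leaf(X_test)            # flat array of length n_trees * n_instances
    leaf_ids = leaf_ids.reshape(model.n_trees, -1)   # row t = leaf node id of each instance in tree t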

    def predict_proba(self, X, groups=None):
        if self.model is None:
@@ -235,7 +282,6 @@ def predict_proba(self, X, groups=None):
        return self.predict_proba


-
    def save_model(self, model_path):
        if self.model is None:
            print("Please train the model first or load model from file!")
@@ -350,6 +396,75 @@ def cv(self, X, y, folds=None, nfold=5, shuffle=True, seed=0):
            print("mean test RMSE:%.6f+%.6f" % (statistics.mean(test_score_list), statistics.stdev(test_score_list)))
        return self.eval_res

+   def centralize_train_a_subtree(self, X, y, n_layer, input_gradient_g=None, input_gradient_h=None, groups=None):
+       n_ins = len(X)
+       if self.model is not None:
+           fedtree.model_free(byref(self.model))
+           self.model = None
+       sparse = sp.issparse(X)
+       if sparse is False:
+           # potential bug: csr_matrix ignores all zero values in X
+           X = sp.csr_matrix(X)
+       X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
+
+       X.data = np.asarray(X.data, dtype=np.float32, order='C')
+       X.sort_indices()
+       data = X.data.ctypes.data_as(POINTER(c_float))
+       indices = X.indices.ctypes.data_as(POINTER(c_int32))
+       indptr = X.indptr.ctypes.data_as(POINTER(c_int32))
+       y = np.asarray(y, dtype=np.float32, order='C')
+       label = y.ctypes.data_as(POINTER(c_float))
+       in_groups, num_groups = self._construct_groups(groups)
+       group_label = (c_float * len(set(y)))()
+       n_class = (c_int * 1)()
+       n_class[0] = self.num_class
+       tree_per_iter_ptr = (c_int * 1)()
+       self.model = (c_long * 1)()
+       n_max_node = pow(2, n_layer)
+       # needs to represent instance ID as int
+       insid_list = (c_int * n_ins)()
+       n_ins_list = (c_int * n_max_node)()
+       gradient_g_list = (c_float * n_ins)()
+       gradient_h_list = (c_float * n_ins)()
+       n_node = (c_int * 1)()
+       nodeid_list = (c_int * n_max_node)()
+       input_gradient_g = np.asarray(input_gradient_g, dtype=np.float32, order='C')
+       input_g = input_gradient_g.ctypes.data_as(POINTER(c_float))
+       input_gradient_h = np.asarray(input_gradient_h, dtype=np.float32, order='C')
+       input_h = input_gradient_h.ctypes.data_as(POINTER(c_float))
+       fedtree.centralize_train_a_subtree(c_float(self.variance), c_float(self.privacy_budget),
+           self.max_depth, self.n_trees, c_float(self.min_child_weight), c_float(self.lambda_ft), c_float(self.gamma), c_float(self.column_sampling_rate),
+           self.verbose, self.bagging, self.n_parallel_trees, c_float(self.learning_rate), self.objective.encode('utf-8'), n_class, self.n_device, self.max_num_bin,
+           self.seed, c_float(self.ins_bagging_fraction), self.reorder_label, c_float(self.constant_h),
+           X.shape[0], data, indptr, indices, label, self.tree_method, byref(self.model), tree_per_iter_ptr, group_label,
+           in_groups, num_groups, n_layer, insid_list, n_ins_list, gradient_g_list, gradient_h_list, n_node, nodeid_list, input_g, input_h)
+       self.num_class = n_class[0]
+       self.tree_per_iter = tree_per_iter_ptr[0]
+       self.group_label = [group_label[idx] for idx in range(len(set(y)))]
+
+       self.insid_list = [insid_list[i] for i in range(n_ins)]
+       self.n_ins_list = [n_ins_list[i] for i in range(n_node[0])]
+       self.gradient_g_list = [gradient_g_list[i] for i in range(n_ins)]
+       self.gradient_h_list = [gradient_h_list[i] for i in range(n_ins)]
+       self.n_node = n_node[0]
+       self.nodeid_list = [nodeid_list[i] for i in range(n_node[0])]
+       if self.model is None:
+           print("The model returned is empty!")
+           exit()
+
+       return self
+
+   def update_a_layer_cpp(self, X, ins, nins, gradient_g, gradient_h, n_node, lamb):
+       c_x = np.asarray(X, dtype=np.int32).data.ctypes.data_as(POINTER(c_int32))
+       c_ins = np.asarray(ins, dtype=np.int32).data.ctypes.data_as(POINTER(c_int32))
+       c_nins = np.asarray(nins, dtype=np.int32).data.ctypes.data_as(POINTER(c_int32))
+       c_gradient_g = np.asarray(gradient_g, dtype=np.float32).data.ctypes.data_as(POINTER(c_float))
+       c_gradient_h = np.asarray(gradient_h, dtype=np.float32).data.ctypes.data_as(POINTER(c_float))
+       leaf_val = (c_float * (n_node * 2))()
+       fedtree.update_a_layer_with_flag(c_x, c_ins, c_nins, c_gradient_g, c_gradient_h, n_node, leaf_val)
+       self.leaf_val = [leaf_val[i] for i in range(n_node * 2)]
+
+
 class FLClassifier(FLModel, fedtreeClassifierBase):
     _impl = 'classifier'
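A hypothetical end-to-end use of the layerwise APIs added in this file (an illustrative sketch: the FLRegressor wrapper, parameter defaults, and data are assumptions, not part of the commit):

    import numpy as np
    from fedtree import FLRegressor  # assumed sibling of FLClassifier above

    X = np.random.rand(200, 8)
    y = np.random.rand(200).astype(np.float32)
    g = 2.0 * (0.0 - y)               # squared-error gradients at f(x) = 0
    h = np.full_like(y, 2.0)          # squared-error hessians

    m = FLRegressor()
    m.centralize_train_a_subtree(X, y, n_layer=2, input_gradient_g=g, input_gradient_h=h)
    print(m.n_node, m.nodeid_list)    # frontier size and node ids after 2 layers
    # the returned insid_list / n_ins_list / gradient lists can then be fed to
    # update_a_layer_cpp to grow one more layer and write leaf values back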

python/setup.py

Lines changed: 4 additions & 1 deletion
@@ -18,8 +18,11 @@
 if not path.exists(path.join(dirname, "fedtree", path.basename(lib_path))):
     copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path)))

+# lib_path = "./fedtree/libFedTree.so"
+
+
 setuptools.setup(name="fedtree",
-                 version="1.0.4",
+                 version="1.0.5",
                  packages=["fedtree"],
                  package_dir={"python": "fedtree"},
                  description="A federated learning library for trees",

src/FedTree/Tree/gbdt.cpp

Lines changed: 129 additions & 3 deletions
@@ -39,13 +39,50 @@ void GBDT::train(GBDTParam &param, DataSet &dataset) {
     // float_type score = predict_score(param, dataset);
     // LOG(INFO) << score;

+    auto stop = timer.now();
+    std::chrono::duration<float> training_time = stop - start;
+    LOG(INFO) << "training time = " << training_time.count();
+    return;
+}
+
+void GBDT::train_a_subtree(GBDTParam &param, DataSet &dataset, int n_layer, int *id_list, int *nins_list, float *gradient_g_list,
+                           float *gradient_h_list, int *n_node, int *node_id_list, float *input_gradient_g, float *input_gradient_h) {
+    if (param.tree_method == "auto")
+        param.tree_method = "hist";
+    else if (param.tree_method != "hist") {
+        std::cout << "FedTree only supports histogram-based training for now";
+        exit(1);
+    }
+
+    if (param.objective.find("multi:") != std::string::npos || param.objective.find("binary:") != std::string::npos) {
+        int num_class = dataset.label.size();
+        if (param.num_class != num_class) {
+            LOG(INFO) << "updating number of classes from " << param.num_class << " to " << num_class;
+            param.num_class = num_class;
+        }
+        if (param.num_class > 2)
+            param.tree_per_round = param.num_class;
+    } else if (param.objective.find("reg:") != std::string::npos) {
+        param.num_class = 1;
+    }
+
+    Booster booster;
+    booster.init(dataset, param);
+    std::chrono::high_resolution_clock timer;
+    auto start = timer.now();
+    std::cout << "start boost a subtree" << std::endl;
+    booster.boost_a_subtree(trees, n_layer, id_list, nins_list, gradient_g_list, gradient_h_list, n_node, node_id_list, input_gradient_g, input_gradient_h);
+    //booster.boost(trees);
+    // float_type score = predict_score(param, dataset);
+    // LOG(INFO) << score;

     auto stop = timer.now();
     std::chrono::duration<float> training_time = stop - start;
     LOG(INFO) << "training time = " << training_time.count();
     return;
 }

+
 vector<float_type> GBDT::predict(const GBDTParam &model_param, const DataSet &dataSet) {
     SyncArray<float_type> y_predict;
     predict_raw(model_param, dataSet, y_predict);
@@ -157,9 +194,7 @@ void GBDT::predict_raw(const GBDTParam &model_param, const DataSet &dataSet, Syn
     int num_node = trees[0][0].nodes.size();

     int total_num_node = num_iter * num_class * num_node;
-    //TODO: reduce the output size for binary classification
     y_predict.resize(n_instances * num_class);
-
     SyncArray<Tree::TreeNode> model(total_num_node);
     auto model_data = model.host_data();
     int tree_cnt = 0;
@@ -462,4 +497,95 @@ void GBDT::predict_raw_vertical(const GBDTParam &model_param, const vector<DataS
             predict_data_class[iid] += sum;
         }//end all tree prediction
     }
-}
+}
+
+void GBDT::predict_leaf(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict, int *ins2leaf) {
+    TIMED_SCOPE(timerObj, "predict");
+    int n_instances = dataSet.n_instances();
+    // int n_features = dataSet.n_features();
+
+    // flatten the whole model into an array
+    int num_iter = trees.size();
+    int num_class = trees.front().size();
+    int num_node = trees[0][0].nodes.size();
+
+    int total_num_node = num_iter * num_class * num_node;
+    // y_predict.resize(n_instances * num_class);
+    std::cout << "num_class in predict_leaf:" << num_class << std::endl;
+    SyncArray<Tree::TreeNode> model(total_num_node);
+    auto model_data = model.host_data();
+    int tree_cnt = 0;
+    for (auto &vtree : trees) {
+        for (auto &t : vtree) {
+            memcpy(model_data + num_node * tree_cnt, t.nodes.host_data(), sizeof(Tree::TreeNode) * num_node);
+            tree_cnt++;
+        }
+    }
+
+    PERFORMANCE_CHECKPOINT_WITH_ID(timerObj, "init trees");
+
+    // do prediction
+    auto model_host_data = model.host_data();
+    // auto predict_data = y_predict.host_data();
+    auto csr_col_idx_data = dataSet.csr_col_idx.data();
+    auto csr_val_data = dataSet.csr_val.data();
+    auto csr_row_ptr_data = dataSet.csr_row_ptr.data();
+    auto lr = model_param.learning_rate;
+    PERFORMANCE_CHECKPOINT_WITH_ID(timerObj, "copy data");
+
+#pragma omp parallel for
+    for (int iid = 0; iid < n_instances; iid++) {
+        auto get_next_child = [&](Tree::TreeNode node, float_type feaValue) {
+            //return feaValue < node.split_value ? node.lch_index : node.rch_index;
+            return (feaValue - node.split_value) >= -1e-6 ? node.rch_index : node.lch_index;
+        };
+        auto get_val = [&](const int *row_idx, const float_type *row_val, int row_len, int idx,
+                           bool *is_missing) -> float_type {
+            // binary search over the CSR row to get the feature value
+            const int *left = row_idx;
+            const int *right = row_idx + row_len;
+
+            while (left != right) {
+                const int *mid = left + (right - left) / 2;
+                if (*mid == idx) {
+                    *is_missing = false;
+                    return row_val[mid - row_idx];
+                }
+                if (*mid > idx)
+                    right = mid;
+                else left = mid + 1;
+            }
+            *is_missing = true;
+            return 0;
+        };
+        const int *col_idx = csr_col_idx_data + csr_row_ptr_data[iid];
+        const float_type *row_val = csr_val_data + csr_row_ptr_data[iid];
+        int row_len = csr_row_ptr_data[iid + 1] - csr_row_ptr_data[iid];
+        for (int t = 0; t < num_class; t++) {
+            // auto predict_data_class = predict_data + t * n_instances;
+            // float_type sum = 0;
+            for (int iter = 0; iter < num_iter; iter++) {
+                const Tree::TreeNode *node_data = model_host_data + iter * num_class * num_node + t * num_node;
+                Tree::TreeNode curNode = node_data[0];
+                int cur_nid = 0; //node id
+                while (!curNode.is_leaf) {
+                    int fid = curNode.split_feature_id;
+                    bool is_missing;
+                    float_type fval = get_val(col_idx, row_val, row_len, fid, &is_missing);
+                    if (!is_missing)
+                        cur_nid = get_next_child(curNode, fval);
+                    else if (curNode.default_right)
+                        cur_nid = curNode.rch_index;
+                    else
+                        cur_nid = curNode.lch_index;
+
+                    curNode = node_data[cur_nid];
+                }
+                // slot is indexed by (iter, iid) only, so with more than one
+                // class per iteration the last class written wins
+                ins2leaf[iter * n_instances + iid] = cur_nid;
+                // sum += lr * curNode.base_weight;
+            }
+        }//end all tree prediction
+    }
+}
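A minimal Python sketch of the per-instance traversal that predict_leaf performs (illustrative only: the node fields mirror the C++ above, and the binary search over the CSR row is replaced by a dict lookup for clarity):

    def leaf_of(instance, nodes):
        # instance: {feature_id: value}; nodes: tree nodes indexed by node id
        nid = 0
        while not nodes[nid].is_leaf:
            node = nodes[nid]
            fval = instance.get(node.split_feature_id)
            if fval is None:                        # missing feature: take the default branch
                nid = node.rch_index if node.default_right else node.lch_index
            elif fval - node.split_value >= -1e-6:  # same comparison as get_next_child
                nid = node.rch_index
            else:
                nid = node.lch_index
        return nid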
