Skip to content

Commit 8b62f92

Browse files
committed
fix issue #40
1 parent 00f0416 commit 8b62f92

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path)))
2020

2121
setuptools.setup(name="fedtree",
22-
version="1.0.3",
22+
version="1.0.4",
2323
packages=["fedtree"],
2424
package_dir={"python": "fedtree"},
2525
description="A federated learning library for trees",

src/FedTree/scikit_fedtree.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ extern "C" {
234234
test_dataset.label.emplace_back(group_label[i]);
235235
}
236236
}
237+
else{
238+
for (int i = 0; i < num_class; ++i) {
239+
test_dataset.label.emplace_back(i);
240+
}
241+
}
237242
// predict
238243
SyncArray<float_type> y_predict;
239244
vector<vector<Tree>> boosted_model_in_mem;
@@ -274,8 +279,16 @@ extern "C" {
274279
test_dataset.load_from_sparse(row_size, val, row_ptr, col_ptr, NULL, group, num_group, model_param);
275280
set_logger(verbose);
276281
test_dataset.label.clear();
277-
for (int i = 0; i < num_class; ++i) {
278-
test_dataset.label.emplace_back(group_label[i]);
282+
if(group_label != NULL) {
283+
test_dataset.label.clear();
284+
for (int i = 0; i < num_class; ++i) {
285+
test_dataset.label.emplace_back(group_label[i]);
286+
}
287+
}
288+
else{
289+
for (int i = 0; i < num_class; ++i) {
290+
test_dataset.label.emplace_back(i);
291+
}
279292
}
280293
// predict
281294
SyncArray<float_type> y_predict;

0 commit comments

Comments
 (0)