
Commit 12fe2fc

Fix federated learning demos and tests (dmlc#9488)

1 parent: b2e93d2
10 files changed: +60 -11 lines

demo/nvflare/.gitignore

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+!config

(new file: NVFlare client job configuration; the filename is not preserved in this view)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+{
+  "format_version": 2,
+  "executors": [
+    {
+      "tasks": [
+        "train"
+      ],
+      "executor": {
+        "path": "trainer.XGBoostTrainer",
+        "args": {
+          "server_address": "localhost:9091",
+          "world_size": 2,
+          "server_cert_path": "server-cert.pem",
+          "client_key_path": "client-key.pem",
+          "client_cert_path": "client-cert.pem",
+          "use_gpus": false
+        }
+      }
+    }
+  ],
+  "task_result_filters": [],
+  "task_data_filters": []
+}
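
NVFlare instantiates the executor from this config by importing the class named in "path" and passing the "args" map to its constructor as keyword arguments, so the JSON keys must match the executor's __init__ signature exactly. That is why the new "use_gpus" key is paired with the use_gpus parameter added to demo/nvflare/vertical/custom/trainer.py below. A minimal sketch of the resulting call (illustrative only; NVFlare performs the instantiation itself):

```python
# Illustrative only: NVFlare imports "path" and passes "args" as
# keyword arguments, roughly equivalent to this direct call.
from trainer import XGBoostTrainer  # the demo's custom executor

executor = XGBoostTrainer(
    server_address="localhost:9091",
    world_size=2,                        # two participating sites
    server_cert_path="server-cert.pem",
    client_key_path="client-key.pem",
    client_cert_path="client-cert.pem",
    use_gpus=False,  # new key; omitting it would now raise a TypeError
)
```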

(new file: NVFlare server job configuration; the filename is not preserved in this view)

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+{
+  "format_version": 2,
+  "server": {
+    "heart_beat_timeout": 600
+  },
+  "task_data_filters": [],
+  "task_result_filters": [],
+  "workflows": [
+    {
+      "id": "server_workflow",
+      "path": "controller.XGBoostController",
+      "args": {
+        "port": 9091,
+        "world_size": 2,
+        "server_key_path": "server-key.pem",
+        "server_cert_path": "server-cert.pem",
+        "client_cert_path": "client-cert.pem"
+      }
+    }
+  ],
+  "components": []
+}
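
The server-side workflow mirrors the client config: "controller.XGBoostController" receives these args and uses them to stand up the federated gRPC server that the trainers connect to on port 9091. A hedged sketch of what that amounts to, assuming the controller delegates to xgboost.federated.run_federated_server (the exact signature may differ between XGBoost versions):

```python
# Hedged sketch: start the federated gRPC server with the values from
# the "args" block above. Assumes xgboost.federated.run_federated_server
# from a build with the federated plugin enabled; verify the exact API
# against your XGBoost version.
import xgboost.federated

xgboost.federated.run_federated_server(
    port=9091,
    world_size=2,                      # number of connecting clients
    server_key_path="server-key.pem",  # TLS material for secure gRPC
    server_cert_path="server-cert.pem",
    client_cert_path="client-cert.pem",
)
```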

demo/nvflare/horizontal/README.md

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ This directory contains a demo of Horizontal Federated Learning using
 ## Training with CPU only
 
 To run the demo, first build XGBoost with the federated learning plugin enabled (see the
-[README](../../plugin/federated/README.md)).
+[README](../../../plugin/federated/README.md)).
 
 Install NVFlare (note that currently NVFlare only supports Python 3.8):
 ```shell

demo/nvflare/horizontal/prepare_data.sh

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ split -n l/${world_size} --numeric-suffixes=1 -a 1 ../../data/agaricus.txt.test
 
 nvflare poc -n 2 --prepare
 mkdir -p /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
-cp -fr config custom /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
+cp -fr ../config custom /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
 cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
 for (( site=1; site<=world_size; site++ )); do
   cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"$site"/

demo/nvflare/vertical/README.md

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ This directory contains a demo of Vertical Federated Learning using
 ## Training with CPU only
 
 To run the demo, first build XGBoost with the federated learning plugin enabled (see the
-[README](../../plugin/federated/README.md)).
+[README](../../../plugin/federated/README.md)).
 
 Install NVFlare (note that currently NVFlare only supports Python 3.8):
 ```shell

demo/nvflare/vertical/custom/trainer.py

Lines changed: 4 additions & 1 deletion

@@ -16,7 +16,7 @@ class SupportedTasks(object):
 
 class XGBoostTrainer(Executor):
     def __init__(self, server_address: str, world_size: int, server_cert_path: str,
-                 client_key_path: str, client_cert_path: str):
+                 client_key_path: str, client_cert_path: str, use_gpus: bool):
         """Trainer for federated XGBoost.
 
         Args:
@@ -32,6 +32,7 @@ def __init__(self, server_address: str, world_size: int, server_cert_path: str,
         self._server_cert_path = server_cert_path
         self._client_key_path = client_key_path
         self._client_cert_path = client_cert_path
+        self._use_gpus = use_gpus
 
     def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
                 abort_signal: Signal) -> Shareable:
@@ -81,6 +82,8 @@ def _do_training(self, fl_ctx: FLContext):
             'objective': 'binary:logistic',
             'eval_metric': 'auc',
         }
+        if self._use_gpus:
+            self.log_info(fl_ctx, 'GPUs are not currently supported by vertical federated XGBoost')
 
         # specify validations set to watch performance
         watchlist = [(dtest, "eval"), (dtrain, "train")]
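
For orientation, these constructor arguments feed the federated communicator environment that each worker passes to xgb.collective.CommunicatorContext, the same mechanism the updated test below uses. A hedged sketch follows; the exact key names are assumptions based on the federated plugin of this period, not text from this commit:

```python
# Hedged sketch of a federated worker's communicator setup; the key
# names are assumed and should be checked against the plugin README.
import xgboost as xgb

communicator_env = {
    'xgboost_communicator': 'federated',
    'federated_server_address': 'localhost:9091',
    'federated_world_size': 2,
    'federated_rank': 0,                        # this site's rank
    'federated_server_cert': 'server-cert.pem',
    'federated_client_key': 'client-key.pem',
    'federated_client_cert': 'client-cert.pem',
}

with xgb.collective.CommunicatorContext(**communicator_env):
    # DMatrix construction and xgb.train(...) belong here; collective
    # communication is only valid inside this context.
    pass
```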

demo/nvflare/vertical/prepare_data.sh

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ fi
 
 nvflare poc -n 2 --prepare
 mkdir -p /tmp/nvflare/poc/admin/transfer/vertical-xgboost
-cp -fr config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
+cp -fr ../config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
 cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
 for (( site=1; site<=world_size; site++ )); do
   cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"${site}"/

tests/test_distributed/test_federated/runtests-federated.sh

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@ openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out se
 openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
 
 # Split train and test files manually to simulate a federated environment.
-split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.train agaricus.txt.train-
-split -n l/"${world_size}" -d ../../demo/data/agaricus.txt.test agaricus.txt.test-
+split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.train agaricus.txt.train-
+split -n l/"${world_size}" -d ../../../demo/data/agaricus.txt.test agaricus.txt.test-
 
 python test_federated.py "${world_size}"

tests/test_distributed/test_federated/test_federated.py

Lines changed: 4 additions & 4 deletions

@@ -35,14 +35,14 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
     # Always call this before using distributed module
     with xgb.collective.CommunicatorContext(**communicator_env):
         # Load file, file will not be sharded in federated mode.
-        dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
-        dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
+        dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
+        dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
 
         # Specify parameters via map, definition are same as c++ version
         param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
         if with_gpu:
-            param['tree_method'] = 'gpu_hist'
-            param['gpu_id'] = rank
+            param['tree_method'] = 'hist'
+            param['device'] = f"cuda:{rank}"
 
         # Specify validations set to watch performance
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
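
Both changes track API updates in newer XGBoost: DMatrix no longer guesses the format of text files, so the URI must name it explicitly, and the gpu_hist tree method plus gpu_id are superseded by the single hist method with a device parameter. A standalone illustration (the file name simply mirrors the split output above):

```python
# Standalone illustration of the two updated idioms used by the test.
import xgboost as xgb

# Text inputs must declare their format in the DMatrix URI.
dtrain = xgb.DMatrix('agaricus.txt.train-00?format=libsvm')

param = {
    'max_depth': 2,
    'eta': 1,
    'objective': 'binary:logistic',
    'tree_method': 'hist',  # one tree method for both CPU and GPU
    'device': 'cuda:0',     # replaces tree_method='gpu_hist', gpu_id=0
}
booster = xgb.train(param, dtrain, num_boost_round=2)
```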
