Skip to content

Commit 644e001

Browse files
authored
[backport][fed] Fixes for the encrypted GRPC backend. (dmlc#10503) (dmlc#10577)
1 parent 3f47fcb commit 644e001

File tree

9 files changed

+192
-109
lines changed

9 files changed

+192
-109
lines changed

plugin/federated/federated_comm.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023, XGBoost contributors
2+
* Copyright 2023-2024, XGBoost contributors
33
*/
44
#include "federated_comm.h"
55

@@ -11,6 +11,7 @@
1111
#include <string> // for string, stoi
1212

1313
#include "../../src/common/common.h" // for Split
14+
#include "../../src/common/io.h" // for ReadAll
1415
#include "../../src/common/json_utils.h" // for OptionalArg
1516
#include "xgboost/json.h" // for Json
1617
#include "xgboost/logging.h"
@@ -46,9 +47,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
4647
} else {
4748
stub_ = [&] {
4849
grpc::SslCredentialsOptions options;
49-
options.pem_root_certs = server_cert;
50-
options.pem_private_key = client_key;
51-
options.pem_cert_chain = client_cert;
50+
options.pem_root_certs = common::ReadAll(server_cert);
51+
options.pem_private_key = common::ReadAll(client_key);
52+
options.pem_cert_chain = common::ReadAll(client_cert);
5253
grpc::ChannelArguments args;
5354
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
5455
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),

python-package/xgboost/federated.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def __init__( # pylint: disable=R0913, W0231
3939
n_workers: int,
4040
port: int,
4141
secure: bool,
42-
server_key_path: str = "",
43-
server_cert_path: str = "",
44-
client_cert_path: str = "",
42+
server_key_path: Optional[str] = None,
43+
server_cert_path: Optional[str] = None,
44+
client_cert_path: Optional[str] = None,
4545
timeout: int = 300,
4646
) -> None:
4747
handle = ctypes.c_void_p()
@@ -84,7 +84,13 @@ def run_federated_server( # pylint: disable=too-many-arguments
8484
for path in [server_key_path, server_cert_path, client_cert_path]
8585
)
8686
tracker = FederatedTracker(
87-
n_workers=n_workers, port=port, secure=secure, timeout=timeout
87+
n_workers=n_workers,
88+
port=port,
89+
secure=secure,
90+
timeout=timeout,
91+
server_key_path=server_key_path,
92+
server_cert_path=server_cert_path,
93+
client_cert_path=client_cert_path,
8894
)
8995
tracker.start()
9096

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# pylint: disable=unbalanced-tuple-unpacking, too-many-locals
2+
"""Tests for federated learning."""
3+
4+
import multiprocessing
5+
import os
6+
import subprocess
7+
import tempfile
8+
import time
9+
from typing import List, cast
10+
11+
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
12+
from sklearn.model_selection import train_test_split
13+
14+
import xgboost as xgb
15+
import xgboost.federated
16+
from xgboost import testing as tm
17+
from xgboost.training import TrainingCallback
18+
19+
SERVER_KEY = "server-key.pem"
20+
SERVER_CERT = "server-cert.pem"
21+
CLIENT_KEY = "client-key.pem"
22+
CLIENT_CERT = "client-cert.pem"
23+
24+
25+
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
26+
"""Run federated server for test."""
27+
if with_ssl:
28+
xgboost.federated.run_federated_server(
29+
world_size,
30+
port,
31+
server_key_path=SERVER_KEY,
32+
server_cert_path=SERVER_CERT,
33+
client_cert_path=CLIENT_CERT,
34+
)
35+
else:
36+
xgboost.federated.run_federated_server(world_size, port)
37+
38+
39+
def run_worker(
40+
port: int, world_size: int, rank: int, with_ssl: bool, device: str
41+
) -> None:
42+
"""Run federated client worker for test."""
43+
communicator_env = {
44+
"dmlc_communicator": "federated",
45+
"federated_server_address": f"localhost:{port}",
46+
"federated_world_size": world_size,
47+
"federated_rank": rank,
48+
}
49+
if with_ssl:
50+
communicator_env["federated_server_cert_path"] = SERVER_CERT
51+
communicator_env["federated_client_key_path"] = CLIENT_KEY
52+
communicator_env["federated_client_cert_path"] = CLIENT_CERT
53+
54+
cpu_count = os.cpu_count()
55+
assert cpu_count is not None
56+
n_threads = cpu_count // world_size
57+
58+
# Always call this before using distributed module
59+
with xgb.collective.CommunicatorContext(**communicator_env):
60+
# Load file, file will not be sharded in federated mode.
61+
X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
62+
dtrain = xgb.DMatrix(X, y)
63+
X, y = load_svmlight_file(f"agaricus.txt-{rank}.test")
64+
dtest = xgb.DMatrix(X, y)
65+
66+
# Specify parameters via a map; the definitions are the same as in the C++ version
67+
param = {
68+
"max_depth": 2,
69+
"eta": 1,
70+
"objective": "binary:logistic",
71+
"nthread": n_threads,
72+
"tree_method": "hist",
73+
"device": device,
74+
}
75+
76+
# Specify the validation sets used to watch performance
77+
watchlist = [(dtest, "eval"), (dtrain, "train")]
78+
num_round = 20
79+
80+
# Run training; all the features of the training API are available.
81+
results: TrainingCallback.EvalsLog = {}
82+
bst = xgb.train(
83+
param,
84+
dtrain,
85+
num_round,
86+
evals=watchlist,
87+
early_stopping_rounds=2,
88+
evals_result=results,
89+
)
90+
assert tm.non_increasing(cast(List[float], results["train"]["logloss"]))
91+
assert tm.non_increasing(cast(List[float], results["eval"]["logloss"]))
92+
93+
# Save the model; only process 0 is asked to save it.
94+
if xgb.collective.get_rank() == 0:
95+
with tempfile.TemporaryDirectory() as tmpdir:
96+
bst.save_model(os.path.join(tmpdir, "model.json"))
97+
xgb.collective.communicator_print("Finished training\n")
98+
99+
100+
def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
101+
"""Launcher for clients and the server."""
102+
port = 9091
103+
104+
server = multiprocessing.Process(
105+
target=run_server, args=(port, world_size, with_ssl)
106+
)
107+
server.start()
108+
time.sleep(1)
109+
if not server.is_alive():
110+
raise ValueError("Error starting Federated Learning server")
111+
112+
workers = []
113+
for rank in range(world_size):
114+
device = f"cuda:{rank}" if use_gpu else "cpu"
115+
worker = multiprocessing.Process(
116+
target=run_worker, args=(port, world_size, rank, with_ssl, device)
117+
)
118+
workers.append(worker)
119+
worker.start()
120+
for worker in workers:
121+
worker.join()
122+
server.terminate()
123+
124+
125+
def run_federated_learning(with_ssl: bool, use_gpu: bool, test_path: str) -> None:
126+
"""Run federated learning tests."""
127+
n_workers = 2
128+
129+
if with_ssl:
130+
command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost" # pylint: disable=line-too-long
131+
server_key = command.format(part="server").split()
132+
subprocess.check_call(server_key)
133+
client_key = command.format(part="client").split()
134+
subprocess.check_call(client_key)
135+
136+
train_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.train")
137+
test_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.test")
138+
139+
X_train, y_train = load_svmlight_file(train_path)
140+
X_test, y_test = load_svmlight_file(test_path)
141+
142+
X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
143+
X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
144+
X_test, y_test, test_size=0.5
145+
)
146+
147+
dump_svmlight_file(X0, y0, "agaricus.txt-0.train")
148+
dump_svmlight_file(X0_valid, y0_valid, "agaricus.txt-0.test")
149+
150+
dump_svmlight_file(X1, y1, "agaricus.txt-1.train")
151+
dump_svmlight_file(X1_valid, y1_valid, "agaricus.txt-1.test")
152+
153+
run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)

src/context.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,11 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
191191
}
192192
if (device.IsCUDA()) {
193193
device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
194+
if (!device.IsCUDA()) {
195+
// We allow loading a GPU-based pickle on a CPU-only machine.
196+
LOG(WARNING) << "XGBoost is not compiled with CUDA support.";
197+
}
194198
}
195-
196199
return device;
197200
}
198201
} // namespace

tests/ci_build/lint_python.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class LintersPaths:
3434
"tests/python/test_with_pandas.py",
3535
"tests/python-gpu/",
3636
"tests/python-sycl/",
37+
"tests/test_distributed/test_federated/",
38+
"tests/test_distributed/test_gpu_federated/",
3739
"tests/test_distributed/test_with_dask/",
3840
"tests/test_distributed/test_gpu_with_dask/",
3941
"tests/test_distributed/test_with_spark/",
@@ -94,6 +96,8 @@ class LintersPaths:
9496
"tests/python-gpu/load_pickle.py",
9597
"tests/python-gpu/test_gpu_training_continuation.py",
9698
"tests/python/test_model_io.py",
99+
"tests/test_distributed/test_federated/",
100+
"tests/test_distributed/test_gpu_federated/",
97101
"tests/test_distributed/test_with_spark/test_data.py",
98102
"tests/test_distributed/test_gpu_with_spark/test_data.py",
99103
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",

tests/ci_build/test_python.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ case "$suite" in
7070
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
7171
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask
7272
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark
73+
pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_federated
7374
unset_pyspark_envs
7475
uninstall_xgboost
7576
set +x
@@ -84,6 +85,7 @@ case "$suite" in
8485
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
8586
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask
8687
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark
88+
pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_federated
8789
unset_pyspark_envs
8890
uninstall_xgboost
8991
set +x

tests/test_distributed/test_federated/runtests-federated.sh

Lines changed: 0 additions & 17 deletions
This file was deleted.
Lines changed: 5 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,8 @@
1-
#!/usr/bin/python
2-
import multiprocessing
3-
import sys
4-
import time
1+
import pytest
52

6-
import xgboost as xgb
7-
import xgboost.federated
3+
from xgboost.testing.federated import run_federated_learning
84

9-
SERVER_KEY = 'server-key.pem'
10-
SERVER_CERT = 'server-cert.pem'
11-
CLIENT_KEY = 'client-key.pem'
12-
CLIENT_CERT = 'client-cert.pem'
135

14-
15-
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
16-
if with_ssl:
17-
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
18-
CLIENT_CERT)
19-
else:
20-
xgboost.federated.run_federated_server(port, world_size)
21-
22-
23-
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
24-
communicator_env = {
25-
'xgboost_communicator': 'federated',
26-
'federated_server_address': f'localhost:{port}',
27-
'federated_world_size': world_size,
28-
'federated_rank': rank
29-
}
30-
if with_ssl:
31-
communicator_env['federated_server_cert'] = SERVER_CERT
32-
communicator_env['federated_client_key'] = CLIENT_KEY
33-
communicator_env['federated_client_cert'] = CLIENT_CERT
34-
35-
# Always call this before using distributed module
36-
with xgb.collective.CommunicatorContext(**communicator_env):
37-
# Load file, file will not be sharded in federated mode.
38-
dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
39-
dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
40-
41-
# Specify parameters via map, definition are same as c++ version
42-
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
43-
if with_gpu:
44-
param['tree_method'] = 'hist'
45-
param['device'] = f"cuda:{rank}"
46-
47-
# Specify validations set to watch performance
48-
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
49-
num_round = 20
50-
51-
# Run training, all the features in training API is available.
52-
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
53-
early_stopping_rounds=2)
54-
55-
# Save the model, only ask process 0 to save the model.
56-
if xgb.collective.get_rank() == 0:
57-
bst.save_model("test.model.json")
58-
xgb.collective.communicator_print("Finished training\n")
59-
60-
61-
def run_federated(with_ssl: bool = True, with_gpu: bool = False) -> None:
62-
port = 9091
63-
world_size = int(sys.argv[1])
64-
65-
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
66-
server.start()
67-
time.sleep(1)
68-
if not server.is_alive():
69-
raise Exception("Error starting Federated Learning server")
70-
71-
workers = []
72-
for rank in range(world_size):
73-
worker = multiprocessing.Process(target=run_worker,
74-
args=(port, world_size, rank, with_ssl, with_gpu))
75-
workers.append(worker)
76-
worker.start()
77-
for worker in workers:
78-
worker.join()
79-
server.terminate()
80-
81-
82-
if __name__ == '__main__':
83-
run_federated(with_ssl=True, with_gpu=False)
84-
run_federated(with_ssl=False, with_gpu=False)
85-
run_federated(with_ssl=True, with_gpu=True)
86-
run_federated(with_ssl=False, with_gpu=True)
6+
@pytest.mark.parametrize("with_ssl", [True, False])
7+
def test_federated_learning(with_ssl: bool) -> None:
8+
run_federated_learning(with_ssl, False, __file__)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pytest
2+
3+
from xgboost.testing.federated import run_federated_learning
4+
5+
6+
@pytest.mark.parametrize("with_ssl", [True, False])
7+
@pytest.mark.mgpu
8+
def test_federated_learning(with_ssl: bool) -> None:
9+
run_federated_learning(with_ssl, True, __file__)

0 commit comments

Comments
 (0)