Skip to content

Commit 63b49f3

Browse files
authored
[backport] Allow blocking launch of federated tracker. (dmlc#10414) (dmlc#10425)
1 parent 6094106 commit 63b49f3

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

plugin/example/custom_obj.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction {
6969

7070
void SaveConfig(Json* p_out) const override {
7171
auto& out = *p_out;
72-
out["name"] = String("my_logistic");
72+
out["name"] = String("mylogistic");
7373
out["my_logistic_param"] = ToJson(param_);
7474
}
7575

python-package/xgboost/federated.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,19 @@ def run_federated_server( # pylint: disable=too-many-arguments
6565
server_key_path: Optional[str] = None,
6666
server_cert_path: Optional[str] = None,
6767
client_cert_path: Optional[str] = None,
68+
blocking: bool = True,
6869
timeout: int = 300,
69-
) -> Dict[str, Any]:
70-
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info."""
70+
) -> Optional[Dict[str, Any]]:
71+
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info.
72+
73+
Parameters
74+
----------
75+
blocking :
76+
Block the server until the training is finished. If set to False, the function
77+
launches an additional thread and returns the worker arguments. The default is
78+
True and a higher level framework is responsible for setting worker parameters.
79+
80+
"""
7181
args: Dict[str, Any] = {"n_workers": n_workers}
7282
secure = all(
7383
path is not None
@@ -78,6 +88,10 @@ def run_federated_server( # pylint: disable=too-many-arguments
7888
)
7989
tracker.start()
8090

91+
if blocking:
92+
tracker.wait_for()
93+
return None
94+
8195
thread = Thread(target=tracker.wait_for)
8296
thread.daemon = True
8397
thread.start()

tests/python/test_collective.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_federated_communicator():
6363
world_size = 2
6464
tracker = multiprocessing.Process(
6565
target=federated.run_federated_server,
66-
kwargs={"port": port, "n_workers": world_size},
66+
kwargs={"port": port, "n_workers": world_size, "blocking": False},
6767
)
6868
tracker.start()
6969
if not tracker.is_alive():

0 commit comments

Comments
 (0)