Skip to content

Commit 5b76acc

Browse files
authored
Add back xgboost.rabit for backwards compatibility (dmlc#8408) (dmlc#8411)
1 parent 4bc59ef commit 5b76acc

File tree

3 files changed

+200
-1
lines changed

3 files changed

+200
-1
lines changed

python-package/xgboost/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from . import tracker # noqa
7-
from . import collective, dask
7+
from . import collective, dask, rabit
88
from .core import (
99
Booster,
1010
DataIter,

python-package/xgboost/rabit.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Compatibility shim for xgboost.rabit; to be removed in 2.0"""
2+
import logging
3+
import warnings
4+
from enum import IntEnum, unique
5+
from typing import Any, TypeVar, Callable, Optional, List
6+
7+
import numpy as np
8+
9+
from . import collective
10+
11+
LOGGER = logging.getLogger("[xgboost.rabit]")
12+
13+
14+
def _deprecation_warning() -> str:
15+
return (
16+
"The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed "
17+
"in 2.0. Please use xgboost.collective instead."
18+
)
19+
20+
21+
def init(args: Optional[List[bytes]] = None) -> None:
    """Initialize the rabit library with arguments.

    Parameters
    ----------
    args :
        ``key=value`` pairs, each encoded as bytes, forwarded to
        :py:func:`xgboost.collective.init` as keyword arguments.
        Entries containing no ``=`` are ignored.
    """
    warnings.warn(_deprecation_warning(), FutureWarning)
    parsed = {}
    if args:
        for arg in args:
            # Split only on the first '=' so values that themselves contain
            # '=' (tokens, URIs, ...) are preserved instead of silently dropped.
            key, sep, value = arg.decode().partition("=")
            if sep:
                parsed[key] = value
    collective.init(**parsed)
31+
32+
33+
def finalize() -> None:
    """Shut down the worker and notify the tracker that everything is done."""
    collective.finalize()
36+
37+
38+
def get_rank() -> int:
    """Get rank of current process.

    Returns
    -------
    rank : int
        Rank of the current process in the communicator group.
    """
    return collective.get_rank()
46+
47+
48+
def get_world_size() -> int:
    """Get the total number of workers.

    Returns
    -------
    n : int
        Total number of processes in the communicator group.
    """
    return collective.get_world_size()
56+
57+
58+
def is_distributed() -> int:
    """Whether rabit is running in distributed mode."""
    return collective.is_distributed()
61+
62+
63+
def tracker_print(msg: Any) -> None:
    """Print message to the tracker.

    Used to communicate progress information to the tracker process.

    Parameters
    ----------
    msg : str
        The message to be printed to tracker.
    """
    collective.communicator_print(msg)
73+
74+
75+
def get_processor_name() -> bytes:
    """Get the processor name.

    Returns
    -------
    name : bytes
        The name of the processor (host), encoded to bytes for
        backward compatibility with the historical rabit API.
    """
    return collective.get_processor_name().encode()
83+
84+
85+
T = TypeVar("T")  # pylint:disable=invalid-name


def broadcast(data: T, root: int) -> T:
    """Broadcast a python object from one node to all other nodes.

    Parameters
    ----------
    data : any type that can be pickled
        Input data; on ranks other than ``root`` this may be ``None``.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object :
        The result of the broadcast.
    """
    return collective.broadcast(data, root)
102+
103+
104+
@unique
class Op(IntEnum):
    """Reduction operators supported by the legacy rabit interface."""

    MAX = 0
    MIN = 1
    SUM = 2
    OR = 3


def allreduce(  # pylint:disable=invalid-name
    data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
    """Perform allreduce, return the result.

    Parameters
    ----------
    data :
        Input data.
    op :
        Reduction operator, one of MIN, MAX, SUM, OR.
    prepare_fun :
        Historical lazy-preprocessing hook; no longer supported.
        Passing anything other than ``None`` raises.

    Returns
    -------
    result :
        The result of allreduce, with the same shape as ``data``.

    Notes
    -----
    This function is not thread-safe.
    """
    # Guard clause: the legacy preprocessing hook has no collective equivalent.
    if prepare_fun is not None:
        raise Exception("preprocessing function is no longer supported")
    return collective.allreduce(data, collective.Op(op))
139+
140+
141+
def version_number() -> int:
    """Returns version number of currently stored model.

    This means how many calls to CheckPoint we made so far.
    Always 0 in this compatibility shim: checkpointing is not supported
    by the collective backend.

    Returns
    -------
    version : int
        Version number of currently stored model.
    """
    return 0
150+
151+
152+
class RabitContext:
    """A context manager controlling rabit initialization and finalization.

    Emits a deprecation warning on entry; use
    ``xgboost.collective.CommunicatorContext`` instead.
    """

    def __init__(self, args: Optional[List[bytes]] = None) -> None:
        # ``None`` default (not ``[]``) so instances never share a mutable list;
        # the annotation now matches the accepted values.
        self.args = [] if args is None else args

    def __enter__(self) -> None:
        init(self.args)
        assert is_distributed()
        LOGGER.warning(_deprecation_warning())
        LOGGER.debug("-------------- rabit say hello ------------------")

    def __exit__(self, *args: Any) -> None:
        finalize()
        LOGGER.debug("--------------- rabit say bye ------------------")

tests/python/test_collective.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,37 @@ def test_rabit_communicator():
3939
assert worker.exitcode == 0
4040

4141

42+
# TODO(rongou): remove this once we remove the rabit api.
def run_rabit_api_worker(rabit_env, world_size):
    """Exercise the deprecated xgb.rabit shim inside a communicator context."""
    with xgb.rabit.RabitContext(rabit_env):
        assert xgb.rabit.get_world_size() == world_size
        assert xgb.rabit.is_distributed()
        assert xgb.rabit.get_processor_name().decode() == socket.gethostname()
        broadcast_result = xgb.rabit.broadcast('test1234', 0)
        assert str(broadcast_result) == 'test1234'
        reduced = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM)
        assert np.array_equal(reduced, np.asarray([2, 4, 6]))
52+
53+
54+
# TODO(rongou): remove this once we remove the rabit api.
def test_rabit_api():
    """Spin up a tracker plus two workers and run the legacy rabit API."""
    world_size = 2
    tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
    tracker.start(world_size)
    # Workers receive the tracker environment as legacy "key=value" byte args.
    rabit_env = [
        f"{k}={v}".encode() for k, v in tracker.worker_envs().items()
    ]
    workers = []
    for _ in range(world_size):
        worker = multiprocessing.Process(target=run_rabit_api_worker,
                                         args=(rabit_env, world_size))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
        assert worker.exitcode == 0
71+
72+
4273
def run_federated_worker(port, world_size, rank):
4374
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
4475
federated_server_address=f'localhost:{port}',

0 commit comments

Comments
 (0)