|
| 1 | +"""Compatibility shim for xgboost.rabit; to be removed in 2.0""" |
| 2 | +import logging |
| 3 | +import warnings |
| 4 | +from enum import IntEnum, unique |
| 5 | +from typing import Any, TypeVar, Callable, Optional, List |
| 6 | + |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from . import collective |
| 10 | + |
| 11 | +LOGGER = logging.getLogger("[xgboost.rabit]") |
| 12 | + |
| 13 | + |
| 14 | +def _deprecation_warning() -> str: |
| 15 | + return ( |
| 16 | + "The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed " |
| 17 | + "in 2.0. Please use xgboost.collective instead." |
| 18 | + ) |
| 19 | + |
| 20 | + |
| 21 | +def init(args: Optional[List[bytes]] = None) -> None: |
| 22 | + """Initialize the rabit library with arguments""" |
| 23 | + warnings.warn(_deprecation_warning(), FutureWarning) |
| 24 | + parsed = {} |
| 25 | + if args: |
| 26 | + for arg in args: |
| 27 | + kv = arg.decode().split('=') |
| 28 | + if len(kv) == 2: |
| 29 | + parsed[kv[0]] = kv[1] |
| 30 | + collective.init(**parsed) |
| 31 | + |
| 32 | + |
| 33 | +def finalize() -> None: |
| 34 | + """Finalize the process, notify tracker everything is done.""" |
| 35 | + collective.finalize() |
| 36 | + |
| 37 | + |
| 38 | +def get_rank() -> int: |
| 39 | + """Get rank of current process. |
| 40 | + Returns |
| 41 | + ------- |
| 42 | + rank : int |
| 43 | + Rank of current process. |
| 44 | + """ |
| 45 | + return collective.get_rank() |
| 46 | + |
| 47 | + |
| 48 | +def get_world_size() -> int: |
| 49 | + """Get total number workers. |
| 50 | + Returns |
| 51 | + ------- |
| 52 | + n : int |
| 53 | + Total number of process. |
| 54 | + """ |
| 55 | + return collective.get_world_size() |
| 56 | + |
| 57 | + |
| 58 | +def is_distributed() -> int: |
| 59 | + """If rabit is distributed.""" |
| 60 | + return collective.is_distributed() |
| 61 | + |
| 62 | + |
| 63 | +def tracker_print(msg: Any) -> None: |
| 64 | + """Print message to the tracker. |
| 65 | + This function can be used to communicate the information of |
| 66 | + the progress to the tracker |
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + msg : str |
| 70 | + The message to be printed to tracker. |
| 71 | + """ |
| 72 | + collective.communicator_print(msg) |
| 73 | + |
| 74 | + |
| 75 | +def get_processor_name() -> bytes: |
| 76 | + """Get the processor name. |
| 77 | + Returns |
| 78 | + ------- |
| 79 | + name : str |
| 80 | + the name of processor(host) |
| 81 | + """ |
| 82 | + return collective.get_processor_name().encode() |
| 83 | + |
| 84 | + |
| 85 | +T = TypeVar("T") # pylint:disable=invalid-name |
| 86 | + |
| 87 | + |
| 88 | +def broadcast(data: T, root: int) -> T: |
| 89 | + """Broadcast object from one node to all other nodes. |
| 90 | + Parameters |
| 91 | + ---------- |
| 92 | + data : any type that can be pickled |
| 93 | + Input data, if current rank does not equal root, this can be None |
| 94 | + root : int |
| 95 | + Rank of the node to broadcast data from. |
| 96 | + Returns |
| 97 | + ------- |
| 98 | + object : int |
| 99 | + the result of broadcast. |
| 100 | + """ |
| 101 | + return collective.broadcast(data, root) |
| 102 | + |
| 103 | + |
| 104 | +@unique |
| 105 | +class Op(IntEnum): |
| 106 | + """Supported operations for rabit.""" |
| 107 | + MAX = 0 |
| 108 | + MIN = 1 |
| 109 | + SUM = 2 |
| 110 | + OR = 3 |
| 111 | + |
| 112 | + |
| 113 | +def allreduce( # pylint:disable=invalid-name |
| 114 | + data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None |
| 115 | +) -> np.ndarray: |
| 116 | + """Perform allreduce, return the result. |
| 117 | + Parameters |
| 118 | + ---------- |
| 119 | + data : |
| 120 | + Input data. |
| 121 | + op : |
| 122 | + Reduction operators, can be MIN, MAX, SUM, BITOR |
| 123 | + prepare_fun : |
| 124 | + Lazy preprocessing function, if it is not None, prepare_fun(data) |
| 125 | + will be called by the function before performing allreduce, to initialize the data |
| 126 | + If the result of Allreduce can be recovered directly, |
| 127 | + then prepare_fun will NOT be called |
| 128 | + Returns |
| 129 | + ------- |
| 130 | + result : |
| 131 | + The result of allreduce, have same shape as data |
| 132 | + Notes |
| 133 | + ----- |
| 134 | + This function is not thread-safe. |
| 135 | + """ |
| 136 | + if prepare_fun is None: |
| 137 | + return collective.allreduce(data, collective.Op(op)) |
| 138 | + raise Exception("preprocessing function is no longer supported") |
| 139 | + |
| 140 | + |
| 141 | +def version_number() -> int: |
| 142 | + """Returns version number of current stored model. |
| 143 | + This means how many calls to CheckPoint we made so far. |
| 144 | + Returns |
| 145 | + ------- |
| 146 | + version : int |
| 147 | + Version number of currently stored model |
| 148 | + """ |
| 149 | + return 0 |
| 150 | + |
| 151 | + |
| 152 | +class RabitContext: |
| 153 | + """A context controlling rabit initialization and finalization.""" |
| 154 | + |
| 155 | + def __init__(self, args: List[bytes] = None) -> None: |
| 156 | + if args is None: |
| 157 | + args = [] |
| 158 | + self.args = args |
| 159 | + |
| 160 | + def __enter__(self) -> None: |
| 161 | + init(self.args) |
| 162 | + assert is_distributed() |
| 163 | + LOGGER.warning(_deprecation_warning()) |
| 164 | + LOGGER.debug("-------------- rabit say hello ------------------") |
| 165 | + |
| 166 | + def __exit__(self, *args: List) -> None: |
| 167 | + finalize() |
| 168 | + LOGGER.debug("--------------- rabit say bye ------------------") |
0 commit comments