Skip to content

Commit a5a5810

Browse files
authored
Revamp the rabit implementation. (dmlc#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
1 parent ba9b4cb commit a5a5810

File tree

195 files changed

+2750
-9216
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

195 files changed

+2750
-9216
lines changed

CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF)
6969
option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF)
7070
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
7171
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
72-
option(RABIT_MOCK "Build rabit with mock" OFF)
7372
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
7473
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
7574
## CUDA
@@ -282,9 +281,6 @@ if(MSVC)
282281
endif()
283282
endif()
284283

285-
# rabit
286-
add_subdirectory(rabit)
287-
288284
# core xgboost
289285
add_subdirectory(${xgboost_SOURCE_DIR}/src)
290286
target_link_libraries(objxgboost PUBLIC dmlc)

R-package/src/Makevars.in

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,7 @@ OBJECTS= \
106106
$(PKGROOT)/src/collective/comm.o \
107107
$(PKGROOT)/src/collective/comm_group.o \
108108
$(PKGROOT)/src/collective/coll.o \
109-
$(PKGROOT)/src/collective/communicator-inl.o \
110109
$(PKGROOT)/src/collective/tracker.o \
111-
$(PKGROOT)/src/collective/communicator.o \
112-
$(PKGROOT)/src/collective/in_memory_communicator.o \
113110
$(PKGROOT)/src/collective/in_memory_handler.o \
114111
$(PKGROOT)/src/collective/loop.o \
115112
$(PKGROOT)/src/collective/socket.o \
@@ -134,7 +131,4 @@ OBJECTS= \
134131
$(PKGROOT)/src/common/version.o \
135132
$(PKGROOT)/src/c_api/c_api.o \
136133
$(PKGROOT)/src/c_api/c_api_error.o \
137-
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
138-
$(PKGROOT)/rabit/src/engine.o \
139-
$(PKGROOT)/rabit/src/rabit_c_api.o \
140-
$(PKGROOT)/rabit/src/allreduce_base.o
134+
$(PKGROOT)/amalgamation/dmlc-minimum0.o

R-package/src/Makevars.win

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,7 @@ OBJECTS= \
106106
$(PKGROOT)/src/collective/comm.o \
107107
$(PKGROOT)/src/collective/comm_group.o \
108108
$(PKGROOT)/src/collective/coll.o \
109-
$(PKGROOT)/src/collective/communicator-inl.o \
110109
$(PKGROOT)/src/collective/tracker.o \
111-
$(PKGROOT)/src/collective/communicator.o \
112-
$(PKGROOT)/src/collective/in_memory_communicator.o \
113110
$(PKGROOT)/src/collective/in_memory_handler.o \
114111
$(PKGROOT)/src/collective/loop.o \
115112
$(PKGROOT)/src/collective/socket.o \
@@ -134,7 +131,4 @@ OBJECTS= \
134131
$(PKGROOT)/src/common/version.o \
135132
$(PKGROOT)/src/c_api/c_api.o \
136133
$(PKGROOT)/src/c_api/c_api_error.o \
137-
$(PKGROOT)/amalgamation/dmlc-minimum0.o \
138-
$(PKGROOT)/rabit/src/engine.o \
139-
$(PKGROOT)/rabit/src/rabit_c_api.o \
140-
$(PKGROOT)/rabit/src/allreduce_base.o
134+
$(PKGROOT)/amalgamation/dmlc-minimum0.o

cmake/Utils.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target)
151151
target_include_directories(
152152
${target} PRIVATE
153153
${xgboost_SOURCE_DIR}/gputreeshap
154+
${xgboost_SOURCE_DIR}/rabit/include
154155
${CUDAToolkit_INCLUDE_DIRS})
155156

156157
if(MSVC)

demo/dask/cpu_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def main(client: Client) -> None:
1616
m = 100000
1717
n = 100
1818
rng = da.random.default_rng(1)
19-
X = rng.normal(size=(m, n))
19+
X = rng.normal(size=(m, n), chunks=(10000, -1))
2020
y = X.sum(axis=1)
2121

2222
# DaskDMatrix acts like normal DMatrix, works as a proxy for local

include/xgboost/c_api.h

Lines changed: 105 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
11171117
*
11181118
* @return 0 when success, -1 when failure happens
11191119
*/
1120-
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface,
1121-
char const *c_json_config, DMatrixHandle m,
1120+
XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values,
1121+
char const *config, DMatrixHandle m,
11221122
bst_ulong const **out_shape, bst_ulong *out_dim,
11231123
const float **out_result);
11241124

@@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
15141514
*
15151515
* @brief Experimental support for exposing internal communicator in XGBoost.
15161516
*
1517+
* @note This is still under development.
1518+
*
1519+
* The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
1520+
* changed significantly since its adoption. It consists of a tracker and a set of
1521+
* workers. The tracker is responsible for bootstrapping the communication group and
1522+
* handling centralized tasks like logging. The workers are actual communicators
1523+
* performing collective tasks like allreduce.
1524+
*
1525+
* To use the collective implementation, one needs to first create a tracker with
1526+
* corresponding parameters, then get the arguments for workers using
1527+
* XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
1528+
* XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a
1529+
* XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
1530+
* `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
1531+
* runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
1532+
* runtime is shutting down. This requirement is similar to a Python thread or socket,
1533+
* which should not be relied upon in a `__del__` function.
1534+
*
1535+
* Since it's used as a part of XGBoost, errors will be returned when a XGBoost function
1536+
* is called, for instance, training a booster might return a connection error.
1537+
*
15171538
* @{
15181539
*/
15191540

15201541
/**
1521-
* @brief Handle to tracker.
1542+
* @brief Handle to the tracker.
15221543
*
15231544
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
1524-
* other one is `federated`.
1545+
* other one is `federated`. `rabit` is used for normal collective communication, while
1546+
* `federated` is used for federated learning.
15251547
*
1526-
* This is still under development.
15271548
*/
15281549
typedef void *TrackerHandle; /* NOLINT */
15291550

@@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
15321553
*
15331554
* @param config JSON encoded parameters.
15341555
*
1535-
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
1536-
* and `federated`.
1556+
* - dmlc_communicator: String, the type of tracker to create. Available options are
1557+
* `rabit` and `federated`. See @ref TrackerHandle for more info.
15371558
* - n_workers: Integer, the number of workers.
15381559
* - port: (Optional) Integer, the port this tracker should listen to.
1539-
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
1560+
* - timeout: (Optional) Integer, timeout in seconds for various networking
1561+
operations. Default is 300 seconds.
15401562
*
15411563
* Some configurations are `rabit` specific:
1564+
*
15421565
* - host: (Optional) String, used by the `rabit` tracker to specify the address of the host.
1566+
* This can be useful when the communicator cannot reliably obtain the host address.
1567+
* - sortby: (Optional) Integer.
1568+
* + 0: Sort workers by their host name.
1569+
* + 1: Sort workers by task IDs.
15431570
*
15441571
* Some `federated` specific configurations:
1545-
* - federated_secure: Boolean, whether this is a secure server.
1572+
* - federated_secure: Boolean, whether this is a secure server. False for testing.
15461573
* - server_key_path: Path to the server key. Used only if this is a secure server.
15471574
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
15481575
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
@@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
15981625
*/
15991626
XGB_DLL int XGTrackerFree(TrackerHandle handle);
16001627

1601-
/*!
1602-
* \brief Initialize the collective communicator.
1628+
/**
1629+
* @brief Initialize the collective communicator.
16031630
*
16041631
* Currently the communicator API is experimental, function signatures may change in the future
16051632
* without notice.
16061633
*
1607-
* Call this once before using anything.
1608-
*
1609-
* The additional configuration is not required. Usually the communicator will detect settings
1610-
* from environment variables.
1634+
* Call this once in the worker process before using anything. Please make sure
1635+
* XGCommunicatorFinalize() is called after use. The initialized communicator is a global
1636+
* thread-local variable.
16111637
*
1612-
* \param config JSON encoded configuration. Accepted JSON keys are:
1613-
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
1638+
* @param config JSON encoded configuration. Accepted JSON keys are:
1639+
* - dmlc_communicator: The type of the communicator, this should match the tracker type.
16141640
* * rabit: Use Rabit. This is the default if the type is unspecified.
16151641
* * federated: Use the gRPC interface for Federated Learning.
1616-
* Only applicable to the Rabit communicator (these are case-sensitive):
1617-
* - rabit_tracker_uri: Hostname of the tracker.
1618-
* - rabit_tracker_port: Port number of the tracker.
1619-
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
1620-
* - rabit_world_size: Total number of workers.
1621-
* - rabit_timeout: Enable timeout.
1622-
* - rabit_timeout_sec: Timeout in seconds.
1623-
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
1624-
* environment variables):
1625-
* - DMLC_TRACKER_URI: Hostname of the tracker.
1626-
* - DMLC_TRACKER_PORT: Port number of the tracker.
1627-
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
1628-
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
1629-
* - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
1630-
* `USE_DLOPEN_NCCL`.
1631-
* Only applicable to the Federated communicator (use upper case for environment variables, use
1642+
*
1643+
* Only applicable to the `rabit` communicator:
1644+
* - dmlc_tracker_uri: Hostname or IP address of the tracker.
1645+
* - dmlc_tracker_port: Port number of the tracker.
1646+
* - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
1647+
* - dmlc_retry: The number of retries for connection failure.
1648+
* - dmlc_timeout: Timeout in seconds.
1649+
* - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
1650+
*
1651+
* Only applicable to the `federated` communicator (use upper case for environment variables, use
16321652
* lower case for runtime configuration):
16331653
* - federated_server_address: Address of the federated server.
16341654
* - federated_world_size: Number of federated workers.
16351655
* - federated_rank: Rank of the current worker.
1636-
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
1637-
* - federated_client_key: Client key file path. Only needed for the SSL mode.
1638-
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
1639-
* \return 0 for success, -1 for failure.
1656+
* - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
1657+
* - federated_client_key_path: Client key file path. Only needed for the SSL mode.
1658+
* - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
1659+
*
1660+
* @return 0 for success, -1 for failure.
16401661
*/
16411662
XGB_DLL int XGCommunicatorInit(char const* config);
16421663

1643-
/*!
1644-
* \brief Finalize the collective communicator.
1664+
/**
1665+
* @brief Finalize the collective communicator.
16451666
*
1646-
* Call this function after you finished all jobs.
1667+
* Call this function after you have finished all jobs.
16471668
*
1648-
* \return 0 for success, -1 for failure.
1669+
* @return 0 for success, -1 for failure.
16491670
*/
16501671
XGB_DLL int XGCommunicatorFinalize(void);
16511672

1652-
/*!
1653-
* \brief Get rank of current process.
1673+
/**
1674+
* @brief Get rank of the current process.
16541675
*
1655-
* \return Rank of the worker.
1676+
* @return Rank of the worker.
16561677
*/
16571678
XGB_DLL int XGCommunicatorGetRank(void);
16581679

1659-
/*!
1660-
* \brief Get total number of processes.
1680+
/**
1681+
* @brief Get the total number of processes.
16611682
*
1662-
* \return Total world size.
1683+
* @return Total world size.
16631684
*/
16641685
XGB_DLL int XGCommunicatorGetWorldSize(void);
16651686

1666-
/*!
1667-
* \brief Get if the communicator is distributed.
1687+
/**
1688+
* @brief Get if the communicator is distributed.
16681689
*
1669-
* \return True if the communicator is distributed.
1690+
* @return True if the communicator is distributed.
16701691
*/
16711692
XGB_DLL int XGCommunicatorIsDistributed(void);
16721693

1673-
/*!
1674-
* \brief Print the message to the communicator.
1694+
/**
1695+
* @brief Print the message to the tracker.
16751696
*
1676-
* This function can be used to communicate the information of the progress to the user who monitors
1677-
* the communicator.
1697+
* This function can be used to communicate the information of the progress to the user
1698+
* who monitors the tracker.
16781699
*
1679-
* \param message The message to be printed.
1680-
* \return 0 for success, -1 for failure.
1700+
* @param message The message to be printed.
1701+
* @return 0 for success, -1 for failure.
16811702
*/
16821703
XGB_DLL int XGCommunicatorPrint(char const *message);
16831704

1684-
/*!
1685-
* \brief Get the name of the processor.
1705+
/**
1706+
* @brief Get the name of the processor.
16861707
*
1687-
* \param name_str Pointer to received returned processor name.
1688-
* \return 0 for success, -1 for failure.
1708+
* @param name_str Pointer that receives the returned processor name.
1709+
* @return 0 for success, -1 for failure.
16891710
*/
16901711
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
16911712

1692-
/*!
1693-
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
1713+
/**
1714+
* @brief Broadcast a memory region to all others from root. This function is NOT
1715+
* thread-safe.
16941716
*
16951717
* Example:
1696-
* \code
1718+
* @code
16971719
* int a = 1;
16981720
* Broadcast(&a, sizeof(a), root);
1699-
* \endcode
1721+
* @endcode
17001722
*
1701-
* \param send_receive_buffer Pointer to the send or receive buffer.
1702-
* \param size Size of the data.
1703-
* \param root The process rank to broadcast from.
1704-
* \return 0 for success, -1 for failure.
1723+
* @param send_receive_buffer Pointer to the send or receive buffer.
1724+
* @param size Size of the data in bytes.
1725+
* @param root The process rank to broadcast from.
1726+
* @return 0 for success, -1 for failure.
17051727
*/
17061728
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
17071729

1708-
/*!
1709-
* \brief Perform in-place allreduce. This function is NOT thread-safe.
1730+
/**
1731+
* @brief Perform in-place allreduce. This function is NOT thread-safe.
17101732
*
17111733
* Example Usage: the following code gives sum of the result
1712-
* \code
1713-
* vector<int> data(10);
1734+
* @code
1735+
* enum class Op {
1736+
* kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
1737+
* };
1738+
* std::vector<int> data(10);
17141739
* ...
1715-
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
1740+
* Allreduce(data.data(), data.size(), DataType::kInt32, Op::kSum);
17161741
* ...
1717-
* \endcode
1742+
* @endcode
17181743
1719-
* \param send_receive_buffer Buffer for both sending and receiving data.
1720-
* \param count Number of elements to be reduced.
1721-
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
1722-
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
1723-
* \return 0 for success, -1 for failure.
1744+
* @param send_receive_buffer Buffer for both sending and receiving data.
1745+
* @param count Number of elements to be reduced.
1746+
* @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
1747+
* @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
1748+
*
1749+
* @return 0 for success, -1 for failure.
17241750
*/
17251751
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);
17261752

include/xgboost/collective/result.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,9 @@ struct ResultImpl {
5555
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
5656
#define __builtin_FILE() nullptr
5757
#define __builtin_LINE() (-1)
58-
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
59-
#else
60-
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
6158
#endif
59+
60+
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
6261
} // namespace detail
6362

6463
/**

0 commit comments

Comments
 (0)