@@ -1117,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
1117
1117
*
1118
1118
* @return 0 when success, -1 when failure happens
1119
1119
*/
1120
- XGB_DLL int XGBoosterPredictFromColumnar (BoosterHandle handle, char const *array_interface ,
1121
- char const *c_json_config , DMatrixHandle m,
1120
+ XGB_DLL int XGBoosterPredictFromColumnar (BoosterHandle handle, char const *values ,
1121
+ char const *config , DMatrixHandle m,
1122
1122
bst_ulong const **out_shape, bst_ulong *out_dim,
1123
1123
const float **out_result);
1124
1124
@@ -1514,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
1514
1514
*
1515
1515
* @brief Experimental support for exposing internal communicator in XGBoost.
1516
1516
*
1517
+ * @note This is still under development.
1518
+ *
1519
+ * The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has
1520
+ * changed significantly since its adoption. It consists of a tracker and a set of
1521
+ * workers. The tracker is responsible for bootstrapping the communication group and
1522
+ * handling centralized tasks like logging. The workers are actual communicators
1523
+ * performing collective tasks like allreduce.
1524
+ *
1525
+ * To use the collective implementation, one needs to first create a tracker with
1526
+ * corresponding parameters, then get the arguments for workers using
1527
+ * XGTrackerWorkerArgs(). The obtained arguments can then be passed to the
1528
+ * XGCommunicatorInit() function. A call to XGCommunicatorInit() must be accompanied by a
1529
+ * XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses
1530
+ * `std::thread` in C++, which has undefined behavior in a C++ destructor due to the
1531
+ * runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the
1532
+ * runtime is shutting down. This requirement is similar to a Python thread or socket,
1533
+ * which should not be relied upon in a `__del__` function.
1534
+ *
1535
+ * Since it's used as a part of XGBoost, errors will be returned when an XGBoost function
1536
+ * is called, for instance, training a booster might return a connection error.
1537
+ *
1517
1538
* @{
1518
1539
*/
1519
1540
1520
1541
/**
1521
- * @brief Handle to tracker.
1542
+ * @brief Handle to the tracker.
1522
1543
*
1523
1544
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
1524
- * other one is `federated`.
1545
+ * other one is `federated`. `rabit` is used for normal collective communication, while
1546
+ * `federated` is used for federated learning.
1525
1547
*
1526
- * This is still under development.
1527
1548
*/
1528
1549
typedef void *TrackerHandle; /* NOLINT */
1529
1550
@@ -1532,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */
1532
1553
*
1533
1554
* @param config JSON encoded parameters.
1534
1555
*
1535
- * - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
1536
- * and `federated`.
1556
+ * - dmlc_communicator: String, the type of tracker to create. Available options are
1557
+ * `rabit` and `federated`. See @ref TrackerHandle for more info.
1537
1558
* - n_workers: Integer, the number of workers.
1538
1559
* - port: (Optional) Integer, the port this tracker should listen to.
1539
- * - timeout: (Optional) Integer, timeout in seconds for various networking operations.
1560
+ * - timeout: (Optional) Integer, timeout in seconds for various networking
1561
+ * operations. Default is 300 seconds.
1540
1562
*
1541
1563
* Some configurations are `rabit` specific:
1564
+ *
1542
1565
* - host: (Optional) String, used by the `rabit` tracker to specify the address of the host.
1566
+ * This can be useful when the communicator cannot reliably obtain the host address.
1567
+ * - sortby: (Optional) Integer.
1568
+ * + 0: Sort workers by their host name.
1569
+ * + 1: Sort workers by task IDs.
1543
1570
*
1544
1571
* Some `federated` specific configurations:
1545
- * - federated_secure: Boolean, whether this is a secure server.
1572
+ * - federated_secure: Boolean, whether this is a secure server. False for testing.
1546
1573
* - server_key_path: Path to the server key. Used only if this is a secure server.
1547
1574
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
1548
1575
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
@@ -1598,129 +1625,128 @@ XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
1598
1625
*/
1599
1626
XGB_DLL int XGTrackerFree (TrackerHandle handle);
1600
1627
1601
- /* !
1602
- * \ brief Initialize the collective communicator.
1628
+ /**
1629
+ * @brief Initialize the collective communicator.
1603
1630
*
1604
1631
* Currently the communicator API is experimental, function signatures may change in the future
1605
1632
* without notice.
1606
1633
*
1607
- * Call this once before using anything.
1608
- *
1609
- * The additional configuration is not required. Usually the communicator will detect settings
1610
- * from environment variables.
1634
+ * Call this once in the worker process before using anything. Please make sure
1635
+ * XGCommunicatorFinalize() is called after use. The initialized communicator is a global
1636
+ * thread-local variable.
1611
1637
*
1612
- * \ param config JSON encoded configuration. Accepted JSON keys are:
1613
- * - xgboost_communicator : The type of the communicator. Can be set as an environment variable .
1638
+ * @param config JSON encoded configuration. Accepted JSON keys are:
1639
+ * - dmlc_communicator: The type of the communicator, this should match the tracker type.
1614
1640
* * rabit: Use Rabit. This is the default if the type is unspecified.
1615
1641
* * federated: Use the gRPC interface for Federated Learning.
1616
- * Only applicable to the Rabit communicator (these are case-sensitive):
1617
- * - rabit_tracker_uri: Hostname of the tracker.
1618
- * - rabit_tracker_port: Port number of the tracker.
1619
- * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
1620
- * - rabit_world_size: Total number of workers.
1621
- * - rabit_timeout: Enable timeout.
1622
- * - rabit_timeout_sec: Timeout in seconds.
1623
- * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
1624
- * environment variables):
1625
- * - DMLC_TRACKER_URI: Hostname of the tracker.
1626
- * - DMLC_TRACKER_PORT: Port number of the tracker.
1627
- * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
1628
- * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
1629
- * - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with
1630
- * `USE_DLOPEN_NCCL`.
1631
- * Only applicable to the Federated communicator (use upper case for environment variables, use
1642
+ *
1643
+ * Only applicable to the `rabit` communicator:
1644
+ * - dmlc_tracker_uri: Hostname or IP address of the tracker.
1645
+ * - dmlc_tracker_port: Port number of the tracker.
1646
+ * - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
1647
+ * - dmlc_retry: The number of retries for connection failure.
1648
+ * - dmlc_timeout: Timeout in seconds.
1649
+ * - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`.
1650
+ *
1651
+ * Only applicable to the `federated` communicator (use upper case for environment variables, use
1632
1652
* lower case for runtime configuration):
1633
1653
* - federated_server_address: Address of the federated server.
1634
1654
* - federated_world_size: Number of federated workers.
1635
1655
* - federated_rank: Rank of the current worker.
1636
- * - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
1637
- * - federated_client_key: Client key file path. Only needed for the SSL mode.
1638
- * - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
1639
- * \return 0 for success, -1 for failure.
1656
+ * - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode.
1657
+ * - federated_client_key_path: Client key file path. Only needed for the SSL mode.
1658
+ * - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode.
1659
+ *
1660
+ * @return 0 for success, -1 for failure.
1640
1661
*/
1641
1662
XGB_DLL int XGCommunicatorInit (char const * config);
1642
1663
1643
- /* !
1644
- * \ brief Finalize the collective communicator.
1664
+ /**
1665
+ * @brief Finalize the collective communicator.
1645
1666
*
1646
- * Call this function after you finished all jobs.
1667
+ * Call this function after you have finished all jobs.
1647
1668
*
1648
- * \ return 0 for success, -1 for failure.
1669
+ * @return 0 for success, -1 for failure.
1649
1670
*/
1650
1671
XGB_DLL int XGCommunicatorFinalize (void );
1651
1672
1652
- /* !
1653
- * \ brief Get rank of current process.
1673
+ /**
1674
+ * @brief Get rank of the current process.
1654
1675
*
1655
- * \ return Rank of the worker.
1676
+ * @return Rank of the worker.
1656
1677
*/
1657
1678
XGB_DLL int XGCommunicatorGetRank (void );
1658
1679
1659
- /* !
1660
- * \ brief Get total number of processes.
1680
+ /**
1681
+ * @brief Get the total number of processes.
1661
1682
*
1662
- * \ return Total world size.
1683
+ * @return Total world size.
1663
1684
*/
1664
1685
XGB_DLL int XGCommunicatorGetWorldSize (void );
1665
1686
1666
- /* !
1667
- * \ brief Get if the communicator is distributed.
1687
+ /**
1688
+ * @brief Get if the communicator is distributed.
1668
1689
*
1669
- * \ return True if the communicator is distributed.
1690
+ * @return True if the communicator is distributed.
1670
1691
*/
1671
1692
XGB_DLL int XGCommunicatorIsDistributed (void );
1672
1693
1673
- /* !
1674
- * \ brief Print the message to the communicator .
1694
+ /**
1695
+ * @brief Print the message to the tracker.
1675
1696
*
1676
- * This function can be used to communicate the information of the progress to the user who monitors
1677
- * the communicator .
1697
+ * This function can be used to communicate the information of the progress to the user
1698
+ * who monitors the tracker.
1678
1699
*
1679
- * \ param message The message to be printed.
1680
- * \ return 0 for success, -1 for failure.
1700
+ * @param message The message to be printed.
1701
+ * @return 0 for success, -1 for failure.
1681
1702
*/
1682
1703
XGB_DLL int XGCommunicatorPrint (char const *message);
1683
1704
1684
- /* !
1685
- * \ brief Get the name of the processor.
1705
+ /**
1706
+ * @brief Get the name of the processor.
1686
1707
*
1687
- * \ param name_str Pointer to received returned processor name.
1688
- * \ return 0 for success, -1 for failure.
1708
+ * @param name_str Pointer that receives the returned processor name.
1709
+ * @return 0 for success, -1 for failure.
1689
1710
*/
1690
1711
XGB_DLL int XGCommunicatorGetProcessorName (const char ** name_str);
1691
1712
1692
- /* !
1693
- * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
1713
+ /**
1714
+ * @brief Broadcast a memory region to all others from root. This function is NOT
1715
+ * thread-safe.
1694
1716
*
1695
1717
* Example:
1696
- * \ code
1718
+ * @code
1697
1719
* int a = 1;
1698
1720
* Broadcast(&a, sizeof(a), root);
1699
- * \ endcode
1721
+ * @endcode
1700
1722
*
1701
- * \ param send_receive_buffer Pointer to the send or receive buffer.
1702
- * \ param size Size of the data.
1703
- * \ param root The process rank to broadcast from.
1704
- * \ return 0 for success, -1 for failure.
1723
+ * @param send_receive_buffer Pointer to the send or receive buffer.
1724
+ * @param size Size of the data in bytes.
1725
+ * @param root The process rank to broadcast from.
1726
+ * @return 0 for success, -1 for failure.
1705
1727
*/
1706
1728
XGB_DLL int XGCommunicatorBroadcast (void *send_receive_buffer, size_t size, int root);
1707
1729
1708
- /* !
1709
- * \ brief Perform in-place allreduce. This function is NOT thread-safe.
1730
+ /**
1731
+ * @brief Perform in-place allreduce. This function is NOT thread-safe.
1710
1732
*
1711
1733
* Example Usage: the following code gives sum of the result
1712
- * \code
1713
- * vector<int> data(10);
1734
+ * @code
1735
+ * enum class Op {
1736
+ * kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
1737
+ * };
1738
+ * std::vector<int> data(10);
1714
1739
* ...
1715
- * Allreduce(& data[0] , data.size(), DataType:kInt32, Op::kSum);
1740
+ * Allreduce(data.data(), data.size(), DataType::kInt32, Op::kSum);
1716
1741
* ...
1717
- * \ endcode
1742
+ * @endcode
1718
1743
1719
- * \param send_receive_buffer Buffer for both sending and receiving data.
1720
- * \param count Number of elements to be reduced.
1721
- * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
1722
- * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
1723
- * \return 0 for success, -1 for failure.
1744
+ * @param send_receive_buffer Buffer for both sending and receiving data.
1745
+ * @param count Number of elements to be reduced.
1746
+ * @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
1747
+ * @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
1748
+ *
1749
+ * @return 0 for success, -1 for failure.
1724
1750
*/
1725
1751
XGB_DLL int XGCommunicatorAllreduce (void *send_receive_buffer, size_t count, int data_type, int op);
1726
1752
0 commit comments