3
3
*/
4
4
#include " threading_utils.h"
5
5
6
- #include < fstream>
7
- #include < string>
6
+ #include < algorithm> // for max
7
+ #include < exception> // for exception
8
+ #include < filesystem> // for path, exists
9
+ #include < fstream> // for ifstream
10
+ #include < string> // for string
8
11
9
- #include " xgboost/logging .h"
12
+ #include " common .h" // for DivRoundUp
10
13
11
- namespace xgboost {
12
- namespace common {
13
- int32_t GetCfsCPUCount () noexcept {
14
+ namespace xgboost ::common {
15
+ /* *
16
+ * Modified from
17
+ * github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp
18
+ *
19
+ * MIT License: Copyright (c) 2016 Domagoj Šarić
20
+ */
21
+ std::int32_t GetCGroupV1Count (std::filesystem::path const & quota_path,
22
+ std::filesystem::path const & peroid_path) {
14
23
#if defined(__linux__)
15
24
// https://bugs.openjdk.java.net/browse/JDK-8146115
16
25
// http://hg.openjdk.java.net/jdk/hs/rev/7f22774a5f42
@@ -31,15 +40,56 @@ int32_t GetCfsCPUCount() noexcept {
31
40
}
32
41
};
33
42
// complete fair scheduler from Linux
34
- auto const cfs_quota (read_int (" /sys/fs/cgroup/cpu/cpu.cfs_quota_us " ));
35
- auto const cfs_period (read_int (" /sys/fs/cgroup/cpu/cpu.cfs_period_us " ));
43
+ auto const cfs_quota (read_int (quota_path. c_str () ));
44
+ auto const cfs_period (read_int (peroid_path. c_str () ));
36
45
if ((cfs_quota > 0 ) && (cfs_period > 0 )) {
37
46
return std::max (cfs_quota / cfs_period, 1 );
38
47
}
39
48
#endif // defined(__linux__)
40
49
return -1 ;
41
50
}
42
51
52
+ std::int32_t GetCGroupV2Count (std::filesystem::path const & bandwidth_path) noexcept (true ) {
53
+ std::int32_t cnt{-1 };
54
+ #if defined(__linux__)
55
+ namespace fs = std::filesystem;
56
+
57
+ std::int32_t a{0 }, b{0 };
58
+
59
+ auto warn = [] { LOG (WARNING) << " Invalid cgroupv2 file." ; };
60
+ try {
61
+ std::ifstream fin{bandwidth_path, std::ios::in};
62
+ fin >> a;
63
+ fin >> b;
64
+ } catch (std::exception const &) {
65
+ warn ();
66
+ return cnt;
67
+ }
68
+ if (a > 0 && b > 0 ) {
69
+ cnt = std::max (common::DivRoundUp (a, b), 1 );
70
+ }
71
+ #endif // defined(__linux__)
72
+ return cnt;
73
+ }
74
+
75
+ std::int32_t GetCfsCPUCount () noexcept {
76
+ namespace fs = std::filesystem;
77
+ fs::path const bandwidth_path{" /sys/fs/cgroup/cpu.max" };
78
+ auto has_v2 = fs::exists (bandwidth_path);
79
+ if (has_v2) {
80
+ return GetCGroupV2Count (bandwidth_path);
81
+ }
82
+
83
+ fs::path const quota_path{" /sys/fs/cgroup/cpu/cpu.cfs_quota_us" };
84
+ fs::path const peroid_path{" /sys/fs/cgroup/cpu/cpu.cfs_period_us" };
85
+ auto has_v1 = fs::exists (quota_path) && fs::exists (peroid_path);
86
+ if (has_v1) {
87
+ return GetCGroupV1Count (quota_path, peroid_path);
88
+ }
89
+
90
+ return -1 ;
91
+ }
92
+
43
93
std::int32_t OmpGetNumThreads (std::int32_t n_threads) {
44
94
// Don't use parallel if we are in a parallel region.
45
95
if (omp_in_parallel ()) {
@@ -54,5 +104,4 @@ std::int32_t OmpGetNumThreads(std::int32_t n_threads) {
54
104
n_threads = std::max (n_threads, 1 );
55
105
return n_threads;
56
106
}
57
- } // namespace common
58
- } // namespace xgboost
107
+ } // namespace xgboost::common
0 commit comments