Skip to content

Commit eae8d47

Browse files
committed
Tracking memory resources (rapidsai#2973)
Detailed tracking of (almost) all allocations on device and host. ```C++ // optionally pass an existing resource handle raft::resources res; // The tracking handle is a child of resource handle; it wraps all memory resources with statistics adaptors raft::memory_tracking_resources tracked(res, "allocations.csv", std::chrono::milliseconds(1)); // All allocations are logged to a .csv as long as `tracked` is alive cuvs::neighbors::cagra::build(tracked, ...); ``` This produces a CSV file with sampled allocations with a timeline and NVTX correlation ```csv timestamp_us,nvtx_depth,nvtx_range,host_current,host_total,pinned_current,pinned_total,managed_current,managed_total,device_current,device_total,workspace_current,workspace_total,large_workspace_current,large_workspace_total 198809,1,"hnsw::build<ACE>",20008,20008,0,0,0,0,148304,148304,0,0,0,0 199961,1,"hnsw::build<ACE>",20008,20008,0,0,0,0,15588304,15588304,0,0,0,0 201350,1,"hnsw::build<ACE>",0,20008,0,0,0,0,0,40385488,0,0,0,0 222216,3,"cagra::build_knn_graph<IVF-PQ>(5000000, 1536, 72)",1440000000,1440020008,0,0,0,0,0,40385488,0,0,0,0 273892,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,40385488,80770976,0,0,0,0 304183,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,40385488,80770976,0,0,4388567040,4388567040 309064,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,53860384,94245872,0,0,4388567040,4388567040 334655,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,67339295,107724783,0,0,4388567040,4388567040 385037,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,74076743,114462231,0,0,4388567040,4388567040 386129,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,80814199,121199687,0,0,4388567040,4388567040 402750,4,"ivf_pq::build(5000000, 1536)",1440020008,1440040016,0,0,0,0,46099768,126913967,0,0,4388567040,4388567040 ... 
```

This can later be visualized (the visualization script is not included in the PR):

<img width="2100" height="1350" alt="allocations" src="https://github.com/user-attachments/assets/3f0ab942-b49b-4e09-a0ea-9181725ae05e" />

#### Implementation overview

##### NVTX

Added thread-local tracking of the NVTX range stack; the calling thread shares a handle with the sampling thread to correlate the NVTX range state with allocations.

##### Memory resource adaptors

- statistics adaptor: atomically counts allocations/deallocations for any `cuda::mr`-compatible resource
- notifying adaptor: sets a shared "notifier" state on each event

##### Resource monitor

A resource monitor registers a collection of resource statistics objects, a single NVTX range handle, and a single notifier state. It spawns a new thread to sample the resource statistics at a given rate (but only when the notifier is triggered). This thread writes to a CSV output stream.

##### Memory tracking resources

`raft::memory_tracking_resources` is a child of `raft::resources`, and thus can be used as a drop-in replacement. It replaces all known memory resources for the duration of its lifetime and manages the output file or stream if necessary.

Depends on (and includes all changes of) rapidsai#2968

Authors:
- Artem M. Chirkin (https://github.com/achirkin)

Approvers:
- Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#2973
1 parent 5475667 commit eae8d47

File tree

11 files changed

+1147
-7
lines changed

11 files changed

+1147
-7
lines changed

cpp/bench/prims/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# =============================================================================
22
# cmake-format: off
3-
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
3+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
44
# SPDX-License-Identifier: Apache-2.0
55
# cmake-format: on
66
# =============================================================================
@@ -67,7 +67,7 @@ function(ConfigureBench)
6767
endfunction()
6868

6969
if(BUILD_PRIMS_BENCH)
70-
ConfigureBench(NAME CORE_BENCH PATH core/bitset.cu core/copy.cu main.cpp)
70+
ConfigureBench(NAME CORE_BENCH PATH core/bitset.cu core/copy.cu core/memory_tracking.cu main.cpp)
7171

7272
ConfigureBench(NAME UTIL_BENCH PATH util/popc.cu main.cpp)
7373

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#include <common/benchmark.hpp>

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/memory_tracking_resources.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/resource_ref.hpp>

#include <unistd.h>

#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <memory>
#include <optional>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
23+
24+
namespace raft::bench::core {
25+
26+
struct tracking_inputs {
27+
int num_allocs;
28+
size_t alloc_size;
29+
int64_t sample_rate_us;
30+
bool batch;
31+
};
32+
33+
struct tracking_overhead : public fixture {
34+
tracking_overhead(const tracking_inputs& p) : fixture(true), params(p)
35+
{
36+
if (p.sample_rate_us >= 0) {
37+
std::string tpl = (std::filesystem::temp_directory_path() / "raft_bench_XXXXXX").string();
38+
int fd = mkstemp(tpl.data());
39+
if (fd != -1) close(fd);
40+
tmp_path_ = std::move(tpl);
41+
tracked_res_.emplace(handle, tmp_path_, std::chrono::microseconds{p.sample_rate_us});
42+
}
43+
}
44+
45+
~tracking_overhead()
46+
{
47+
tracked_res_.reset();
48+
if (!tmp_path_.empty()) { std::remove(tmp_path_.c_str()); }
49+
}
50+
51+
void run_benchmark(::benchmark::State& state) override
52+
{
53+
state.counters["alloc_size"] = params.alloc_size;
54+
state.counters["sample_rate_us"] = params.sample_rate_us;
55+
state.counters["batch"] = params.batch;
56+
57+
run_allocs(state, tracked_res_ ? reinterpret_cast<raft::resources&>(*tracked_res_) : handle);
58+
59+
state.SetItemsProcessed(state.iterations() * params.num_allocs * 2);
60+
}
61+
62+
private:
63+
void run_allocs(::benchmark::State& state, raft::resources& res)
64+
{
65+
auto mr = raft::resource::get_workspace_resource_ref(res);
66+
auto sv = raft::resource::get_cuda_stream(res);
67+
68+
if (params.batch) {
69+
std::vector<void*> ptrs(params.num_allocs);
70+
for (auto _ : state) {
71+
auto t0 = std::chrono::high_resolution_clock::now();
72+
for (int i = 0; i < params.num_allocs; i++)
73+
ptrs[i] = mr.allocate(sv, params.alloc_size);
74+
for (int i = params.num_allocs - 1; i >= 0; i--)
75+
mr.deallocate(sv, ptrs[i], params.alloc_size);
76+
state.SetIterationTime(
77+
std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count());
78+
}
79+
} else {
80+
for (auto _ : state) {
81+
auto t0 = std::chrono::high_resolution_clock::now();
82+
for (int i = 0; i < params.num_allocs; i++) {
83+
void* p = mr.allocate(sv, params.alloc_size);
84+
mr.deallocate(sv, p, params.alloc_size);
85+
}
86+
state.SetIterationTime(
87+
std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count());
88+
}
89+
}
90+
}
91+
92+
tracking_inputs params;
93+
std::string tmp_path_;
94+
std::optional<raft::memory_tracking_resources> tracked_res_ = std::nullopt;
95+
};
96+
97+
const std::vector<tracking_inputs> inputs{
98+
// ping-pong (isolates per-call overhead, pool recycles same block)
99+
{10000, 256, -1, false},
100+
{10000, 256, 0, false},
101+
{10000, 256, 1, false},
102+
{10000, 256, 10, false},
103+
{10000, 256, 100, false},
104+
{10000, 1 << 20, -1, false},
105+
{10000, 1 << 20, 0, false},
106+
{10000, 1 << 20, 1, false},
107+
{10000, 1 << 20, 10, false},
108+
{10000, 1 << 20, 100, false},
109+
{1000, 1 << 26, -1, false},
110+
{1000, 1 << 26, 0, false},
111+
{1000, 1 << 26, 1, false},
112+
{1000, 1 << 26, 10, false},
113+
{1000, 1 << 26, 100, false},
114+
// batch (allocate all, then deallocate all)
115+
{10000, 256, -1, true},
116+
{10000, 256, 0, true},
117+
{10000, 256, 1, true},
118+
{10000, 256, 10, true},
119+
{10000, 256, 100, true},
120+
{1000, 1 << 20, -1, true},
121+
{1000, 1 << 20, 0, true},
122+
{1000, 1 << 20, 1, true},
123+
{1000, 1 << 20, 10, true},
124+
{1000, 1 << 20, 100, true},
125+
};
126+
127+
RAFT_BENCH_REGISTER(tracking_overhead, "", inputs);
128+
129+
} // namespace raft::bench::core

cpp/include/raft/core/detail/nvtx.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

66
#pragma once
77

8+
#include <raft/core/detail/nvtx_range_stack.hpp>
9+
810
#include <rmm/cuda_stream_view.hpp>
911

1012
#ifdef NVTX_ENABLED
@@ -146,6 +148,7 @@ inline void push_range_name(const char* name)
146148
event_attrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
147149
event_attrib.message.ascii = name;
148150
nvtxDomainRangePushEx(domain_store<Domain>::value(), &event_attrib);
151+
detail::range_name_stack_instance.push(name);
149152
}
150153

151154
template <typename Domain, typename... Args>
@@ -168,12 +171,13 @@ inline void push_range(const char* format, Args... args)
168171
template <typename Domain>
169172
inline void pop_range()
170173
{
174+
detail::range_name_stack_instance.pop();
171175
nvtxDomainRangePop(domain_store<Domain>::value());
172176
}
173177

174178
} // namespace raft::common::nvtx::detail
175179

176-
#else // NVTX_ENABLED
180+
#else // NVTX_ENABLED
177181

178182
namespace raft::common::nvtx::detail {
179183

@@ -188,5 +192,4 @@ inline void pop_range()
188192
}
189193

190194
} // namespace raft::common::nvtx::detail
191-
192195
#endif // NVTX_ENABLED
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
#pragma once
6+
7+
#include <cstddef>
8+
#include <memory>
9+
#include <mutex>
10+
#include <stack>
11+
#include <string>
12+
#include <utility>
13+
14+
namespace raft::common::nvtx {

namespace detail {
struct nvtx_range_name_stack;
}  // namespace detail

/**
 * Shared, read-only handle to the current NVTX range name of a specific thread
 * (set internally by one thread, read publicly by zero or more threads).
 *
 * All accesses to the stored name and depth are serialized by an internal mutex.
 */
class current_range {
  friend detail::nvtx_range_name_stack;

 public:
  /**
   * Read the current range name and stack depth (safe to call from any thread).
   *
   * @return a copy of the innermost range name (empty if no range is open) and the number of
   *         open ranges.
   */
  auto get() const -> std::pair<std::string, std::size_t>
  {
    std::lock_guard lock(mu_);
    return {value_, depth_};
  }

  /** Read a copy of the current range name only (empty if no range is open). */
  operator std::string() const
  {
    std::lock_guard lock(mu_);
    return value_;
  }

 private:
  mutable std::mutex mu_;
  std::string value_;
  std::size_t depth_{0};

  /** Publish a new name/depth pair; a null `name` is stored as the empty string. */
  void set(const char* name, std::size_t depth)
  {
    std::lock_guard lock(mu_);
    value_ = name ? name : "";
    depth_ = depth;
  }
};

namespace detail {

/**
 * Stack of the range names pushed so far, mirrored into a shareable `current_range` so that
 * other threads can observe the innermost name and the depth.
 */
struct nvtx_range_name_stack {
  /**
   * Record a newly opened range.
   *
   * @param name range name; a null pointer is recorded as an empty name (constructing a
   *             std::string from a null pointer would be undefined behaviour).
   */
  void push(const char* name)
  {
    stack_.emplace(name != nullptr ? name : "");
    current_->set(stack_.top().c_str(), stack_.size());
  }

  /** Record that the innermost range was closed; a pop on an empty stack is a no-op. */
  void pop()
  {
    if (!stack_.empty()) { stack_.pop(); }
    current_->set(stack_.empty() ? nullptr : stack_.top().c_str(), stack_.size());
  }

  /** Shared read-only view of the innermost range name and depth. */
  auto current() const -> std::shared_ptr<const current_range> { return current_; }

 private:
  std::stack<std::string> stack_{};
  std::shared_ptr<current_range> current_{std::make_shared<current_range>()};
};

/** One range-name stack per thread. */
inline thread_local nvtx_range_name_stack range_name_stack_instance{};

}  // namespace detail

/**
 * Get a read-only handle to this thread's current NVTX range name.
 * Pass the returned shared_ptr to another thread to read this thread's current NVTX range name at
 * any time.
 */
inline auto thread_local_current_range() -> std::shared_ptr<const current_range>
{
  return detail::range_name_stack_instance.current();
}

}  // namespace raft::common::nvtx

0 commit comments

Comments
 (0)