Skip to content

Commit fb3de98

Browse files
committed
Cleanup the stackable resource implementation
1 parent f78594f commit fb3de98

File tree

4 files changed

+111
-103
lines changed

4 files changed

+111
-103
lines changed

cudax/examples/stf/binary_fhe_stackable.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ int main()
224224
auto eA = pA.encrypt();
225225
auto eB = pB.encrypt();
226226

227-
ctx.push_graph();
227+
ctx.push();
228228

229229
eA.push(access_mode::read);
230230
eB.push(access_mode::read);

cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh

Lines changed: 107 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -34,139 +34,155 @@ namespace cuda::experimental::stf
3434
template <typename T>
3535
class stackable_logical_data;
3636

37+
/**
38+
 * @brief This class defines a context that can contain nested subcontexts (implemented as local CUDA graphs)
39+
*/
3740
class stackable_ctx
3841
{
3942
public:
4043
class impl
4144
{
45+
private:
46+
/*
47+
* State of each nested context
48+
*/
49+
struct per_level {
50+
per_level(context ctx, cudaStream_t support_stream, ::std::optional<stream_adapter> alloc_adapters) : ctx(mv(ctx)), support_stream(mv(support_stream)), alloc_adapters(mv(alloc_adapters)) {}
51+
52+
context ctx;
53+
cudaStream_t support_stream;
54+
// A wrapper to forward allocations from a level to the previous one (none is used at the root level)
55+
::std::optional<stream_adapter> alloc_adapters;
56+
};
57+
4258
public:
4359
impl()
4460
{
45-
push(stream_ctx(), nullptr);
61+
push();
4662
}
4763

4864
~impl() = default;
4965

50-
void push(context ctx, cudaStream_t stream)
66+
/**
67+
* @brief Create a new nested level
68+
*/
69+
void push()
5170
{
52-
s.push_back(mv(ctx));
53-
s_stream.push_back(stream);
54-
}
71+
      // These resources are not destroyed when we pop, so we create them only if needed
72+
if (async_handles.size() < levels.size())
73+
{
74+
async_handles.emplace_back();
75+
}
5576

56-
void pop()
57-
{
58-
s.back().finalize();
77+
if (levels.size() == 0) {
78+
levels.emplace_back(stream_ctx(), nullptr, ::std::nullopt);
79+
}
80+
else {
81+
// Get a stream from previous context (we haven't pushed the new one yet)
82+
cudaStream_t stream = levels[depth()].ctx.pick_stream();
5983

60-
s.pop_back();
84+
auto gctx = graph_ctx(stream, async_handles.back());
6185

62-
s_stream.pop_back();
86+
auto wrapper = stream_adapter(gctx, stream);
87+
// FIXME : issue with the deinit phase
88+
// gctx.update_uncached_allocator(wrapper.allocator());
6389

64-
_CCCL_ASSERT(alloc_adapters.size() > 0, "Calling pop from an empty container");
65-
alloc_adapters.back().clear();
66-
alloc_adapters.pop_back();
90+
levels.emplace_back(gctx, stream, wrapper);
91+
}
6792
}
6893

69-
size_t depth() const
94+
/**
95+
* @brief Terminate the current nested level and get back to the previous one
96+
*/
97+
void pop()
7098
{
71-
return s.size() - 1;
72-
}
99+
_CCCL_ASSERT(levels.size() > 0, "Calling pop while no context was pushed");
73100

74-
auto& get()
75-
{
76-
return s.back();
77-
}
101+
auto &current_level = levels.back();
78102

79-
const auto& get() const
80-
{
81-
return s.back();
103+
// Ensure everything is finished in the context
104+
current_level.ctx.finalize();
105+
106+
// Destroy the resources used in the wrapper allocator (if any)
107+
if (current_level.alloc_adapters.has_value())
108+
{
109+
current_level.alloc_adapters.value().clear();
110+
}
111+
112+
// Destroy the current level state
113+
levels.pop_back();
82114
}
83115

84-
auto& operator[](size_t level)
116+
/**
117+
* @brief Get the nesting depth
118+
*/
119+
size_t depth() const
85120
{
86-
_CCCL_ASSERT(level < s.size(), "Out of bound access");
87-
return s[level];
121+
return levels.size() - 1;
88122
}
89123

90-
const auto& operator[](size_t level) const
124+
/**
125+
* @brief Returns a reference to the context at a specific level
126+
*/
127+
auto& get_ctx(size_t level)
91128
{
92-
_CCCL_ASSERT(level < s.size(), "Out of bound access");
93-
return s[level];
129+
return levels[level].ctx;
94130
}
95131

96-
cudaStream_t stream_at(size_t level) const
132+
/**
133+
* @brief Returns a const reference to the context at a specific level
134+
*/
135+
const auto& get_ctx(size_t level) const
97136
{
98-
return s_stream[level];
137+
return levels[level].ctx;
99138
}
100139

101-
void push_graph()
140+
cudaStream_t get_stream(size_t level) const
102141
{
103-
cudaStream_t stream = get().pick_stream();
104-
105-
// These resources are not destroyed when we pop, so we create it only if needed
106-
if (async_handles.size() < s_stream.size())
107-
{
108-
async_handles.emplace_back();
109-
}
110-
111-
auto gctx = graph_ctx(stream, async_handles.back());
112-
113-
auto wrapper = stream_adapter(gctx, stream);
114-
// FIXME : issue with the deinit phase
115-
// gctx.update_uncached_allocator(wrapper.allocator());
116-
117-
alloc_adapters.push_back(wrapper);
118-
119-
push(mv(gctx), stream);
142+
return levels[level].support_stream;
120143
}
121144

122145
private:
123-
::std::vector<context> s;
124-
::std::vector<cudaStream_t> s_stream;
146+
// State for each nested level
147+
::std::vector<per_level> levels;
148+
149+
  // Handles to retain some asynchronous state; we maintain them separately
150+
  // from levels because we keep their entries even when we pop a level
125151
::std::vector<async_resources_handle> async_handles;
126-
::std::vector<stream_adapter> alloc_adapters;
127152
};
128153

129154
stackable_ctx()
130155
: pimpl(::std::make_shared<impl>())
131156
{}
132157

133-
const auto& get() const
158+
cudaStream_t get_stream(size_t level) const
134159
{
135-
return pimpl->get();
136-
}
137-
auto& get()
138-
{
139-
return pimpl->get();
140-
}
141-
142-
auto& operator[](size_t level)
143-
{
144-
return pimpl->operator[](level);
160+
return pimpl->get_stream(level);
145161
}
146162

147-
const auto& operator[](size_t level) const
163+
const auto& get_ctx(size_t level) const
148164
{
149-
return pimpl->operator[](level);
165+
return pimpl->get_ctx(level);
150166
}
151167

152-
cudaStream_t stream_at(size_t level) const
168+
auto& get_ctx(size_t level)
153169
{
154-
return pimpl->stream_at(level);
170+
return pimpl->get_ctx(level);
155171
}
156172

157173
const auto& operator()() const
158174
{
159-
return get();
175+
return get_ctx(depth());
160176
}
161177

162178
auto& operator()()
163179
{
164-
return get();
180+
return get_ctx(depth());
165181
}
166182

167-
void push_graph()
183+
void push()
168184
{
169-
pimpl->push_graph();
185+
pimpl->push();
170186
}
171187

172188
void pop()
@@ -182,33 +198,33 @@ public:
182198
template <typename... Pack>
183199
auto logical_data(Pack&&... pack)
184200
{
185-
return stackable_logical_data(*this, depth(), get().logical_data(::std::forward<Pack>(pack)...));
201+
return stackable_logical_data(*this, depth(), get_ctx(depth()).logical_data(::std::forward<Pack>(pack)...));
186202
}
187203

188204
template <typename... Pack>
189205
auto task(Pack&&... pack)
190206
{
191-
return get().task(::std::forward<Pack>(pack)...);
207+
return get_ctx(depth()).task(::std::forward<Pack>(pack)...);
192208
}
193209

194210
template <typename... Pack>
195211
auto parallel_for(Pack&&... pack)
196212
{
197-
return get().parallel_for(::std::forward<Pack>(pack)...);
213+
return get_ctx(depth()).parallel_for(::std::forward<Pack>(pack)...);
198214
}
199215

200216
template <typename... Pack>
201217
auto host_launch(Pack&&... pack)
202218
{
203-
return get().host_launch(::std::forward<Pack>(pack)...);
219+
return get_ctx(depth()).host_launch(::std::forward<Pack>(pack)...);
204220
}
205221

206222
void finalize()
207223
{
208224
// There must be only one level left
209225
_CCCL_ASSERT(depth() == 0, "All nested levels must have been popped");
210226

211-
get().finalize();
227+
get_ctx(depth()).finalize();
212228
}
213229

214230
public:
@@ -229,12 +245,12 @@ class stackable_logical_data
229245
s.push_back(ld);
230246
}
231247

232-
const auto& get() const
248+
const auto& get_ld() const
233249
{
234250
check_level_mismatch();
235251
return s.back();
236252
}
237-
auto& get()
253+
auto& get_ld()
238254
{
239255
check_level_mismatch();
240256
return s.back();
@@ -243,8 +259,8 @@ class stackable_logical_data
243259
void push(access_mode m, data_place where = data_place::invalid)
244260
{
245261
// We have not pushed yet, so the current depth is the one before pushing
246-
context& from_ctx = sctx[depth()];
247-
context& to_ctx = sctx[depth() + 1];
262+
context& from_ctx = sctx.get_ctx(depth());
263+
context& to_ctx = sctx.get_ctx(depth() + 1);
248264

249265
// Ensure this will match the depth of the context after pushing
250266
_CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth");
@@ -265,7 +281,7 @@ class stackable_logical_data
265281
frozen_s.push_back(f);
266282

267283
// FAKE IMPORT : use the stream needed to support the (graph) ctx
268-
cudaStream_t stream = sctx.stream_at(depth());
284+
cudaStream_t stream = sctx.get_stream(depth());
269285

270286
T inst = f.get(where, stream);
271287
auto ld = to_ctx.logical_data(inst, where);
@@ -282,7 +298,7 @@ class stackable_logical_data
282298
{
283299
// We are going to unfreeze the data, which is currently being used
284300
// in a (graph) ctx that uses this stream to launch the graph
285-
cudaStream_t stream = sctx.stream_at(depth());
301+
cudaStream_t stream = sctx.get_stream(depth());
286302

287303
frozen_logical_data<T>& f = frozen_s.back();
288304
f.unfreeze(stream);
@@ -336,22 +352,13 @@ public:
336352
: pimpl(::std::make_shared<impl>(mv(sctx), depth, mv(ld)))
337353
{}
338354

339-
const auto& get() const
340-
{
341-
return pimpl->get();
342-
}
343-
auto& get()
344-
{
345-
return pimpl->get();
346-
}
347-
348-
const auto& operator()() const
355+
const auto& get_ld() const
349356
{
350-
return get();
357+
return pimpl->get_ld();
351358
}
352-
auto& operator()()
359+
auto& get_ld()
353360
{
354-
return get();
361+
return pimpl->get_ld();
355362
}
356363

357364
size_t depth() const
@@ -363,6 +370,7 @@ public:
363370
{
364371
pimpl->push(m, mv(where));
365372
}
373+
366374
void pop()
367375
{
368376
pimpl->pop();
@@ -372,24 +380,24 @@ public:
372380
template <typename... Pack>
373381
auto read(Pack&&... pack) const
374382
{
375-
return get().read(::std::forward<Pack>(pack)...);
383+
return get_ld().read(::std::forward<Pack>(pack)...);
376384
}
377385

378386
template <typename... Pack>
379387
auto write(Pack&&... pack)
380388
{
381-
return get().write(::std::forward<Pack>(pack)...);
389+
return get_ld().write(::std::forward<Pack>(pack)...);
382390
}
383391

384392
template <typename... Pack>
385393
auto rw(Pack&&... pack)
386394
{
387-
return get().rw(::std::forward<Pack>(pack)...);
395+
return get_ld().rw(::std::forward<Pack>(pack)...);
388396
}
389397

390398
auto shape() const
391399
{
392-
return get().shape();
400+
return get_ld().shape();
393401
}
394402

395403
auto& set_symbol(::std::string symbol)

0 commit comments

Comments
 (0)