@@ -34,139 +34,155 @@ namespace cuda::experimental::stf
3434template <typename T>
3535class stackable_logical_data ;
3636
37+ /* *
38+ * @brief A context that supports nested sub-contexts, where each nested level is implemented as a local CUDA graph
39+ */
3740class stackable_ctx
3841{
3942public:
4043 class impl
4144 {
45+ private:
46+ /*
47+ * State of each nested context
48+ */
49+ struct per_level {
50+ per_level (context ctx, cudaStream_t support_stream, ::std::optional<stream_adapter> alloc_adapters) : ctx(mv(ctx)), support_stream(mv(support_stream)), alloc_adapters(mv(alloc_adapters)) {}
51+
52+ context ctx;
53+ cudaStream_t support_stream;
54+ // A wrapper to forward allocations from a level to the previous one (none is used at the root level)
55+ ::std::optional<stream_adapter> alloc_adapters;
56+ };
57+
4258 public:
4359 impl ()
4460 {
45- push (stream_ctx (), nullptr );
61+ push ();
4662 }
4763
4864 ~impl () = default ;
4965
50- void push (context ctx, cudaStream_t stream)
66+ /* *
67+ * @brief Create a new nested level
68+ */
69+ void push ()
5170 {
52- s.push_back (mv (ctx));
53- s_stream.push_back (stream);
54- }
71+ // These resources are not destroyed when we pop, so we create it only if needed
72+ if (async_handles.size () < levels.size ())
73+ {
74+ async_handles.emplace_back ();
75+ }
5576
56- void pop ()
57- {
58- s.back ().finalize ();
77+ if (levels.size () == 0 ) {
78+ levels.emplace_back (stream_ctx (), nullptr , ::std::nullopt );
79+ }
80+ else {
81+ // Get a stream from previous context (we haven't pushed the new one yet)
82+ cudaStream_t stream = levels[depth ()].ctx .pick_stream ();
5983
60- s. pop_back ( );
84+ auto gctx = graph_ctx (stream, async_handles. back () );
6185
62- s_stream.pop_back ();
86+ auto wrapper = stream_adapter (gctx, stream);
87+ // FIXME : issue with the deinit phase
88+ // gctx.update_uncached_allocator(wrapper.allocator());
6389
64- _CCCL_ASSERT (alloc_adapters.size () > 0 , " Calling pop from an empty container" );
65- alloc_adapters.back ().clear ();
66- alloc_adapters.pop_back ();
90+ levels.emplace_back (gctx, stream, wrapper);
91+ }
6792 }
6893
69- size_t depth () const
94+ /* *
95+ * @brief Terminate the current nested level and get back to the previous one
96+ */
97+ void pop ()
7098 {
71- return s.size () - 1 ;
72- }
99+ _CCCL_ASSERT (levels.size () > 0 , " Calling pop while no context was pushed" );
73100
74- auto & get ()
75- {
76- return s.back ();
77- }
101+ auto ¤t_level = levels.back ();
78102
79- const auto & get () const
80- {
81- return s.back ();
103+ // Ensure everything is finished in the context
104+ current_level.ctx .finalize ();
105+
106+ // Destroy the resources used in the wrapper allocator (if any)
107+ if (current_level.alloc_adapters .has_value ())
108+ {
109+ current_level.alloc_adapters .value ().clear ();
110+ }
111+
112+ // Destroy the current level state
113+ levels.pop_back ();
82114 }
83115
84- auto & operator [](size_t level)
116+ /* *
117+ * @brief Get the nesting depth
118+ */
119+ size_t depth () const
85120 {
86- _CCCL_ASSERT (level < s.size (), " Out of bound access" );
87- return s[level];
121+ return levels.size () - 1 ;
88122 }
89123
90- const auto & operator [](size_t level) const
124+ /* *
125+ * @brief Returns a reference to the context at a specific level
126+ */
127+ auto & get_ctx (size_t level)
91128 {
92- _CCCL_ASSERT (level < s.size (), " Out of bound access" );
93- return s[level];
129+ return levels[level].ctx ;
94130 }
95131
96- cudaStream_t stream_at (size_t level) const
132+ /* *
133+ * @brief Returns a const reference to the context at a specific level
134+ */
135+ const auto & get_ctx (size_t level) const
97136 {
98- return s_stream [level];
137+ return levels [level]. ctx ;
99138 }
100139
101- void push_graph ()
140+ cudaStream_t get_stream ( size_t level) const
102141 {
103- cudaStream_t stream = get ().pick_stream ();
104-
105- // These resources are not destroyed when we pop, so we create it only if needed
106- if (async_handles.size () < s_stream.size ())
107- {
108- async_handles.emplace_back ();
109- }
110-
111- auto gctx = graph_ctx (stream, async_handles.back ());
112-
113- auto wrapper = stream_adapter (gctx, stream);
114- // FIXME : issue with the deinit phase
115- // gctx.update_uncached_allocator(wrapper.allocator());
116-
117- alloc_adapters.push_back (wrapper);
118-
119- push (mv (gctx), stream);
142+ return levels[level].support_stream ;
120143 }
121144
122145 private:
123- ::std::vector<context> s;
124- ::std::vector<cudaStream_t> s_stream;
146+ // State for each nested level
147+ ::std::vector<per_level> levels;
148+
149+ // Handles to retain some asynchronous states, we maintain it separately
150+ // from levels because we keep its entries even when we pop a level
125151 ::std::vector<async_resources_handle> async_handles;
126- ::std::vector<stream_adapter> alloc_adapters;
127152 };
128153
129154 stackable_ctx ()
130155 : pimpl(::std::make_shared<impl>())
131156 {}
132157
133- const auto & get ( ) const
158+ cudaStream_t get_stream ( size_t level ) const
134159 {
135- return pimpl->get ();
136- }
137- auto & get ()
138- {
139- return pimpl->get ();
140- }
141-
142- auto & operator [](size_t level)
143- {
144- return pimpl->operator [](level);
160+ return pimpl->get_stream (level);
145161 }
146162
147- const auto & operator [] (size_t level) const
163+ const auto & get_ctx (size_t level) const
148164 {
149- return pimpl->operator [] (level);
165+ return pimpl->get_ctx (level);
150166 }
151167
152- cudaStream_t stream_at (size_t level) const
168+ auto & get_ctx (size_t level)
153169 {
154- return pimpl->stream_at (level);
170+ return pimpl->get_ctx (level);
155171 }
156172
157173 const auto & operator ()() const
158174 {
159- return get ( );
175+ return get_ctx ( depth () );
160176 }
161177
162178 auto & operator ()()
163179 {
164- return get ( );
180+ return get_ctx ( depth () );
165181 }
166182
167- void push_graph ()
183+ void push ()
168184 {
169- pimpl->push_graph ();
185+ pimpl->push ();
170186 }
171187
172188 void pop ()
@@ -182,33 +198,33 @@ public:
182198 template <typename ... Pack>
183199 auto logical_data (Pack&&... pack)
184200 {
185- return stackable_logical_data (*this , depth (), get ( ).logical_data (::std::forward<Pack>(pack)...));
201+ return stackable_logical_data (*this , depth (), get_ctx ( depth () ).logical_data (::std::forward<Pack>(pack)...));
186202 }
187203
188204 template <typename ... Pack>
189205 auto task (Pack&&... pack)
190206 {
191- return get ( ).task (::std::forward<Pack>(pack)...);
207+ return get_ctx ( depth () ).task (::std::forward<Pack>(pack)...);
192208 }
193209
194210 template <typename ... Pack>
195211 auto parallel_for (Pack&&... pack)
196212 {
197- return get ( ).parallel_for (::std::forward<Pack>(pack)...);
213+ return get_ctx ( depth () ).parallel_for (::std::forward<Pack>(pack)...);
198214 }
199215
200216 template <typename ... Pack>
201217 auto host_launch (Pack&&... pack)
202218 {
203- return get ( ).host_launch (::std::forward<Pack>(pack)...);
219+ return get_ctx ( depth () ).host_launch (::std::forward<Pack>(pack)...);
204220 }
205221
206222 void finalize ()
207223 {
208224 // There must be only one level left
209225 _CCCL_ASSERT (depth () == 0 , " All nested levels must have been popped" );
210226
211- get ( ).finalize ();
227+ get_ctx ( depth () ).finalize ();
212228 }
213229
214230public:
@@ -229,12 +245,12 @@ class stackable_logical_data
229245 s.push_back (ld);
230246 }
231247
232- const auto & get () const
248+ const auto & get_ld () const
233249 {
234250 check_level_mismatch ();
235251 return s.back ();
236252 }
237- auto & get ()
253+ auto & get_ld ()
238254 {
239255 check_level_mismatch ();
240256 return s.back ();
@@ -243,8 +259,8 @@ class stackable_logical_data
243259 void push (access_mode m, data_place where = data_place::invalid)
244260 {
245261 // We have not pushed yet, so the current depth is the one before pushing
246- context& from_ctx = sctx[ depth ()] ;
247- context& to_ctx = sctx[ depth () + 1 ] ;
262+ context& from_ctx = sctx. get_ctx ( depth ()) ;
263+ context& to_ctx = sctx. get_ctx ( depth () + 1 ) ;
248264
249265 // Ensure this will match the depth of the context after pushing
250266 _CCCL_ASSERT (sctx.depth () == depth () + 1 , " Invalid depth" );
@@ -265,7 +281,7 @@ class stackable_logical_data
265281 frozen_s.push_back (f);
266282
267283 // FAKE IMPORT : use the stream needed to support the (graph) ctx
268- cudaStream_t stream = sctx.stream_at (depth ());
284+ cudaStream_t stream = sctx.get_stream (depth ());
269285
270286 T inst = f.get (where, stream);
271287 auto ld = to_ctx.logical_data (inst, where);
@@ -282,7 +298,7 @@ class stackable_logical_data
282298 {
283299 // We are going to unfreeze the data, which is currently being used
284300 // in a (graph) ctx that uses this stream to launch the graph
285- cudaStream_t stream = sctx.stream_at (depth ());
301+ cudaStream_t stream = sctx.get_stream (depth ());
286302
287303 frozen_logical_data<T>& f = frozen_s.back ();
288304 f.unfreeze (stream);
@@ -336,22 +352,13 @@ public:
336352 : pimpl(::std::make_shared<impl>(mv(sctx), depth, mv(ld)))
337353 {}
338354
339- const auto & get () const
340- {
341- return pimpl->get ();
342- }
343- auto & get ()
344- {
345- return pimpl->get ();
346- }
347-
348- const auto & operator ()() const
355+ const auto & get_ld () const
349356 {
350- return get ();
357+ return pimpl-> get_ld ();
351358 }
352- auto & operator () ()
359+ auto & get_ld ()
353360 {
354- return get ();
361+ return pimpl-> get_ld ();
355362 }
356363
357364 size_t depth () const
@@ -363,6 +370,7 @@ public:
363370 {
364371 pimpl->push (m, mv (where));
365372 }
373+
366374 void pop ()
367375 {
368376 pimpl->pop ();
@@ -372,24 +380,24 @@ public:
372380 template <typename ... Pack>
373381 auto read (Pack&&... pack) const
374382 {
375- return get ().read (::std::forward<Pack>(pack)...);
383+ return get_ld ().read (::std::forward<Pack>(pack)...);
376384 }
377385
378386 template <typename ... Pack>
379387 auto write (Pack&&... pack)
380388 {
381- return get ().write (::std::forward<Pack>(pack)...);
389+ return get_ld ().write (::std::forward<Pack>(pack)...);
382390 }
383391
384392 template <typename ... Pack>
385393 auto rw (Pack&&... pack)
386394 {
387- return get ().rw (::std::forward<Pack>(pack)...);
395+ return get_ld ().rw (::std::forward<Pack>(pack)...);
388396 }
389397
390398 auto shape () const
391399 {
392- return get ().shape ();
400+ return get_ld ().shape ();
393401 }
394402
395403 auto & set_symbol (::std::string symbol)
0 commit comments