Skip to content

Commit 5f2e26f

Browse files
authored
[STF] Make green context places independent from async_resources_handle (#7937)
* Introduce a PIMPL idiom in the stream_pool class
* Extract stream related utilities from async_resources_handle.cuh, and from utility/stream_to_dev.cuh, to put them in a header located in places/
* Restore some methods which were erased by mistake
* Fix compilation
* Stream pools no longer belong to the async_resources_handle but are stored directly along with the execution places; create an actual exec_place_device class instead of relying on the base exec_place type
* Remove get_stream_from_pool, which is unnecessary
* Fix where the pool_size variable gets read from
* async_resources_handle.cuh is not needed in places.cuh anymore
* Use the appropriate (sufficient) header
* Make the green context place implementation independent from async_resources_handle
* Only include green context methods with CUDA 12.4+
* Add missing header
1 parent 09b7323 commit 5f2e26f

File tree

12 files changed

+119
-155
lines changed

12 files changed

+119
-155
lines changed

cudax/include/cuda/experimental/__stf/graph/graph_task.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public:
103103
if (is_capture_enabled())
104104
{
105105
// Select a stream from the pool
106-
capture_stream = get_exec_place().getStream(ctx.async_resources(), true).stream;
106+
capture_stream = get_exec_place().getStream(true).stream;
107107
// Use relaxed capture mode to allow capturing workloads that lazily initialize
108108
// resources (e.g., set up memory pools)
109109
cuda_safe_call(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeRelaxed));
@@ -366,7 +366,7 @@ public:
366366
//
367367

368368
// Get a stream from the pool associated to the execution place
369-
capture_stream = get_exec_place().getStream(ctx.async_resources(), true).stream;
369+
capture_stream = get_exec_place().getStream(true).stream;
370370

371371
cudaGraph_t childGraph = nullptr;
372372
// Use relaxed capture mode to allow capturing workloads that lazily initialize
@@ -628,7 +628,7 @@ public:
628628
auto lock = lock_ctx_graph();
629629

630630
// Get a stream from the pool associated to the execution place
631-
cudaStream_t capture_stream = get_exec_place().getStream(ctx.async_resources(), true).stream;
631+
cudaStream_t capture_stream = get_exec_place().getStream(true).stream;
632632

633633
cudaGraph_t childGraph = nullptr;
634634
// Use relaxed capture mode to allow capturing workloads that lazily initialize

cudax/include/cuda/experimental/__stf/internal/async_resources_handle.cuh

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
#include <cuda/experimental/__stf/internal/exec_affinity.cuh>
2929
#include <cuda/experimental/__stf/internal/executable_graph_cache.cuh>
30-
#include <cuda/experimental/__stf/places/stream_pool.cuh>
30+
#include <cuda/experimental/__stf/places/exec/green_context.cuh>
3131
#include <cuda/experimental/__stf/utility/core.cuh>
3232
#include <cuda/experimental/__stf/utility/cuda_safe_call.cuh>
3333
#include <cuda/experimental/__stf/utility/hash.cuh> // for ::std::hash<::std::pair<::std::ptrdiff_t, ::std::ptrdiff_t>>
@@ -40,8 +40,6 @@
4040

4141
namespace cuda::experimental::stf
4242
{
43-
class green_context_helper;
44-
4543
/**
4644
* @brief A handle which stores resources useful for an efficient asynchronous
4745
* execution. For example this will store the pools of CUDA streams.
@@ -52,12 +50,6 @@ class green_context_helper;
5250
*/
5351
class async_resources_handle
5452
{
55-
// TODO: optimize based on measurements
56-
57-
public:
58-
static constexpr size_t pool_size = 4;
59-
static constexpr size_t data_pool_size = 4;
60-
6153
private:
6254
/**
6355
* @brief This class implements a matrix to keep track of the previous
@@ -118,43 +110,25 @@ private:
118110
class impl
119111
{
120112
public:
113+
#if _CCCL_CTK_AT_LEAST(12, 4)
121114
impl()
122115
{
123116
const int ndevices = cuda_try<cudaGetDeviceCount>();
124-
assert(ndevices > 0);
125-
assert(pool_size > 0);
126-
assert(data_pool_size > 0);
127-
117+
_CCCL_ASSERT(ndevices > 0, "invalid device count");
128118
per_device_gc_helper.resize(ndevices, nullptr);
129-
/* For every device, we keep two pools, one dedicated to computation,
130-
* the other for auxiliary methods such as data transfers. This is intended to
131-
* improve overlapping of transfers and computation, for example. */
132-
pool.reserve(ndevices);
133-
for (auto d : each(ndevices))
134-
{
135-
::std::ignore = d;
136-
pool.emplace_back(stream_pool(pool_size), stream_pool(data_pool_size));
137-
}
138-
}
139-
140-
stream_pool& get_device_stream_pool(int dev_id, bool for_computation)
141-
{
142-
assert(dev_id < int(pool.size()));
143-
return for_computation ? pool[dev_id].first : pool[dev_id].second;
144119
}
120+
#endif // _CCCL_CTK_AT_LEAST(12, 4)
145121

146122
public:
147123
// This memorize what was the last event used to synchronize a pair of streams
148124
last_event_per_stream cached_syncs;
149125

150-
// For each device, a pair of stream_pool objects, each stream_pool objects
151-
// stores a pool of streams on this device
152-
::std::vector<::std::pair<stream_pool, stream_pool>> pool;
153-
154126
/* Store previously instantiated graphs, indexed by the number of edges and nodes */
155127
executable_graph_cache cached_graphs;
156128

129+
#if _CCCL_CTK_AT_LEAST(12, 4)
157130
::std::vector<::std::shared_ptr<green_context_helper>> per_device_gc_helper;
131+
#endif // _CCCL_CTK_AT_LEAST(12, 4)
158132

159133
mutable exec_affinity affinity;
160134
};
@@ -173,12 +147,6 @@ public:
173147
return pimpl != nullptr;
174148
}
175149

176-
stream_pool& get_device_stream_pool(int dev_id, bool for_computation) const
177-
{
178-
assert(pimpl);
179-
return pimpl->get_device_stream_pool(dev_id, for_computation);
180-
}
181-
182150
bool validate_sync_and_update(unsigned long long dst, unsigned long long src, int event_id)
183151
{
184152
assert(pimpl);
@@ -192,6 +160,7 @@ public:
192160
return pimpl->cached_graphs.query(nnodes, nedges, mv(g));
193161
}
194162

163+
#if _CCCL_CTK_AT_LEAST(12, 4)
195164
// Get the green context helper cached for this device (or let the user initialize it)
196165
auto& gc_helper(int dev_id)
197166
{
@@ -201,7 +170,17 @@ public:
201170
}
202171

203172
// Get green context helper with lazy initialization
204-
::std::shared_ptr<green_context_helper> get_gc_helper(int dev_id, int sm_count);
173+
::std::shared_ptr<green_context_helper> get_gc_helper(int dev_id, int sm_count)
174+
{
175+
assert(pimpl);
176+
assert(dev_id < int(pimpl->per_device_gc_helper.size()));
177+
auto& h = pimpl->per_device_gc_helper[dev_id];
178+
if (!h)
179+
{
180+
h = ::std::make_shared<green_context_helper>(sm_count, dev_id);
181+
}
182+
return h;
183+
}
205184

206185
// Register an external green context helper
207186
void register_gc_helper(int dev_id, ::std::shared_ptr<green_context_helper> helper)
@@ -210,6 +189,7 @@ public:
210189
assert(dev_id < int(pimpl->per_device_gc_helper.size()));
211190
pimpl->per_device_gc_helper[dev_id] = ::std::move(helper);
212191
}
192+
#endif // _CCCL_CTK_AT_LEAST(12, 4)
213193

214194
exec_affinity& get_affinity()
215195
{

cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ public:
910910
auto pick_dstream()
911911
{
912912
exec_place p = default_exec_place();
913-
return p.get_stream_pool(async_resources(), true).next(p);
913+
return p.get_stream_pool(true).next(p);
914914
}
915915
cudaStream_t pick_stream()
916916
{

cudax/include/cuda/experimental/__stf/places/exec/cuda_stream.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# pragma system_header
2626
#endif // no system header
2727

28-
#include <cuda/experimental/__stf/internal/backend_ctx.cuh>
28+
#include <cuda/experimental/__stf/places/places.cuh>
2929

3030
namespace cuda::experimental::stf
3131
{
@@ -56,7 +56,7 @@ public:
5656
return exec_place::device(dstream.dev_id).deactivate(prev);
5757
}
5858

59-
stream_pool& get_stream_pool(async_resources_handle&, bool) const override
59+
stream_pool& get_stream_pool(bool) const override
6060
{
6161
return dummy_pool;
6262
}

0 commit comments

Comments (0)