@@ -37,13 +37,18 @@ int main() {
       if (!IH.ext_codeplay_has_graph()) {
         assert(false && "Native Handle should have a graph");
       }
-      // Newly created stream for this node
-      auto NativeStream = IH.get_native_queue<backend::ext_oneapi_cuda>();
       // Graph already created with cuGraphCreate
       CUgraph NativeGraph =
           IH.ext_codeplay_get_native_graph<backend::ext_oneapi_cuda>();
 
       // Start stream capture
+      // From CUDA 12.3 onwards we can use cuStreamBeginCaptureToGraph to
+      // capture the stream directly into the native graph, rather than
+      // ending the capture with a separate, newly created graph.
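+      // CUDA_VERSION is defined by cuda.h, so the capture path is chosen
+      // at compile time.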
+#if CUDA_VERSION >= 12030
+      // Newly created stream for this node
+      auto NativeStream = IH.get_native_queue<backend::ext_oneapi_cuda>();
+
       auto Res = cuStreamBeginCaptureToGraph(NativeStream, NativeGraph, nullptr,
                                              nullptr, 0,
                                              CU_STREAM_CAPTURE_MODE_GLOBAL);
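+      // Work submitted to NativeStream between the begin and end capture
+      // calls is recorded into NativeGraph rather than executed immediately.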
@@ -68,6 +73,53 @@ int main() {
 
       Res = cuStreamEndCapture(NativeStream, &NativeGraph);
       assert(Res == CUDA_SUCCESS);
+#else
+      // Use the explicit graph building API to add the alloc/free nodes when
+      // cuStreamBeginCaptureToGraph isn't available.
+      auto Device = IH.get_native_device<backend::ext_oneapi_cuda>();
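+      // Describe a device-resident allocation of Size int32_t elements;
+      // cuGraphAddMemAllocNode returns the address the allocation will have
+      // at graph execution time in AllocParams.dptr.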
+      CUDA_MEM_ALLOC_NODE_PARAMS AllocParams{};
+      AllocParams.bytesize = Size * sizeof(int32_t);
+      AllocParams.poolProps.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
+      AllocParams.poolProps.location.id = Device;
+      AllocParams.poolProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+      CUgraphNode AllocNode;
+      auto Res = cuGraphAddMemAllocNode(&AllocNode, NativeGraph, nullptr, 0,
+                                        &AllocParams);
+      assert(Res == CUDA_SUCCESS);
+
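+      // Fill the allocation with Pattern: Size rows of a single int32_t
+      // element, with a pitch of one element, is equivalent to a linear fill.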
+      CUdeviceptr PtrAsync = AllocParams.dptr;
+      CUDA_MEMSET_NODE_PARAMS MemsetParams{};
+      MemsetParams.dst = PtrAsync;
+      MemsetParams.elementSize = sizeof(int32_t);
+      MemsetParams.height = Size;
+      MemsetParams.pitch = sizeof(int32_t);
+      MemsetParams.value = Pattern;
+      MemsetParams.width = 1;
+      CUgraphNode MemsetNode;
+      CUcontext Context = IH.get_native_context<backend::ext_oneapi_cuda>();
+      Res = cuGraphAddMemsetNode(&MemsetNode, NativeGraph, &AllocNode, 1,
+                                 &MemsetParams, Context);
+      assert(Res == CUDA_SUCCESS);
+
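+      // Copy the filled temporary buffer to the destination pointer PtrX;
+      // with Height and Depth of 1 this is a linear device-to-device copy.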
+      CUDA_MEMCPY3D MemcpyParams{};
+      std::memset(&MemcpyParams, 0, sizeof(CUDA_MEMCPY3D));
+      MemcpyParams.srcMemoryType = CU_MEMORYTYPE_DEVICE;
+      MemcpyParams.srcDevice = PtrAsync;
+      MemcpyParams.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+      MemcpyParams.dstDevice = (CUdeviceptr)PtrX;
+      MemcpyParams.WidthInBytes = Size * sizeof(int32_t);
+      MemcpyParams.Height = 1;
+      MemcpyParams.Depth = 1;
+      CUgraphNode MemcpyNode;
+      Res = cuGraphAddMemcpyNode(&MemcpyNode, NativeGraph, &MemsetNode, 1,
+                                 &MemcpyParams, Context);
+      assert(Res == CUDA_SUCCESS);
+
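+      // Free the temporary allocation; the dependency on MemcpyNode orders
+      // the free after the copy has completed.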
+      CUgraphNode FreeNode;
+      Res = cuGraphAddMemFreeNode(&FreeNode, NativeGraph, &MemcpyNode, 1,
+                                  PtrAsync);
+      assert(Res == CUDA_SUCCESS);
+#endif
     });
   });
 