@@ -26,6 +26,8 @@ Functional ::
2626 2) Pass nodes as nullptr and verify api returns actual number of root nodes added to graph.
2727 3) If NumRootNodes passed is greater than the actual number of root nodes, the remaining entries in
2828 nodes list will be set to NULL, and the number of nodes actually obtained will be returned in NumRootNodes.
29+ 4) Create a graph with stream capture done on multiple dependent streams.
30+ Verify root nodes of created graph are matching the operations pushed which doesn't have dependencies.
2931
3032Argument Validation ::
3133 1) Pass graph as nullptr and verify api returns error code.
@@ -50,6 +52,7 @@ TEST_CASE("Unit_hipGraphGetRootNodes_Functional") {
5052 constexpr auto addlEntries = 5 ;
5153 hipGraph_t graph;
5254
55+
5356 hipGraphNode_t memcpyNode, kernelNode;
5457 hipKernelNodeParams kernelNodeParams{};
5558 hipStream_t streamForGraph;
@@ -141,6 +144,112 @@ TEST_CASE("Unit_hipGraphGetRootNodes_Functional") {
141144 free (rootnodes);
142145}
143146
147+
148+ /* *
149+ * Create a graph with stream capture done on multiple dependent streams. Verify root nodes
150+ * of created graph are matching the operations pushed which doesn't have dependencies.
151+ */
152+ TEST_CASE (" Unit_hipGraphGetRootNodes_CapturedStream" ) {
153+ hipStream_t stream1{nullptr }, stream2{nullptr }, mstream{nullptr };
154+ hipStream_t streamForGraph{nullptr };
155+ hipEvent_t memsetEvent1, memsetEvent2, forkStreamEvent;
156+ hipGraph_t graph{nullptr };
157+ hipGraphExec_t graphExec{nullptr };
158+ constexpr unsigned blocks = 512 ;
159+ constexpr unsigned threadsPerBlock = 256 ;
160+ constexpr size_t N = 1000000 ;
161+ constexpr int numMemsetNodes = 2 ;
162+ size_t Nbytes = N * sizeof (float ), numRootNodes{};
163+ float *A_d, *C_d;
164+ float *A_h, *C_h;
165+ A_h = reinterpret_cast <float *>(malloc (Nbytes));
166+ C_h = reinterpret_cast <float *>(malloc (Nbytes));
167+ REQUIRE (A_h != nullptr );
168+ REQUIRE (C_h != nullptr );
169+ HIP_CHECK (hipMalloc (&A_d, Nbytes));
170+ HIP_CHECK (hipMalloc (&C_d, Nbytes));
171+ REQUIRE (A_d != nullptr );
172+ REQUIRE (C_d != nullptr );
173+
174+ HIP_CHECK (hipStreamCreate (&streamForGraph));
175+
176+ // Initialize input buffer
177+ for (size_t i = 0 ; i < N; ++i) {
178+ A_h[i] = 3 .146f + i; // Pi
179+ }
180+
181+ HIP_CHECK (hipStreamCreate (&stream1));
182+ HIP_CHECK (hipStreamCreate (&stream2));
183+ HIP_CHECK (hipStreamCreate (&mstream));
184+ HIP_CHECK (hipEventCreate (&memsetEvent1));
185+ HIP_CHECK (hipEventCreate (&memsetEvent2));
186+ HIP_CHECK (hipEventCreate (&forkStreamEvent));
187+ HIP_CHECK (hipStreamBeginCapture (mstream, hipStreamCaptureModeGlobal));
188+ HIP_CHECK (hipEventRecord (forkStreamEvent, mstream));
189+ HIP_CHECK (hipStreamWaitEvent (stream1, forkStreamEvent, 0 ));
190+ HIP_CHECK (hipStreamWaitEvent (stream2, forkStreamEvent, 0 ));
191+ HIP_CHECK (hipMemsetAsync (A_d, 0 , Nbytes, stream1));
192+ HIP_CHECK (hipEventRecord (memsetEvent1, stream1));
193+ HIP_CHECK (hipMemsetAsync (C_d, 0 , Nbytes, stream2));
194+ HIP_CHECK (hipEventRecord (memsetEvent2, stream2));
195+ HIP_CHECK (hipStreamWaitEvent (mstream, memsetEvent1, 0 ));
196+ HIP_CHECK (hipStreamWaitEvent (mstream, memsetEvent2, 0 ));
197+ HIP_CHECK (hipMemcpyAsync (A_d, A_h, Nbytes, hipMemcpyHostToDevice, mstream));
198+ hipLaunchKernelGGL (HipTest::vector_square, dim3 (blocks),
199+ dim3 (threadsPerBlock), 0 , mstream, A_d, C_d, N);
200+ HIP_CHECK (hipMemcpyAsync (C_h, C_d, Nbytes, hipMemcpyDeviceToHost, mstream));
201+ HIP_CHECK (hipStreamEndCapture (mstream, &graph));
202+
203+ // Verify numof root nodes
204+ HIP_CHECK (hipGraphGetRootNodes (graph, nullptr , &numRootNodes));
205+ REQUIRE (numRootNodes == numMemsetNodes);
206+ INFO (" Num of nodes returned by GetRootNodes : " << numRootNodes);
207+
208+ int numBytes = sizeof (hipGraphNode_t) * numRootNodes;
209+ hipGraphNode_t* nodes = reinterpret_cast <hipGraphNode_t *>(malloc (numBytes));
210+ REQUIRE (nodes != nullptr );
211+
212+ hipGraphNodeType nodeType;
213+ HIP_CHECK (hipGraphGetRootNodes (graph, nodes, &numRootNodes));
214+ REQUIRE (numRootNodes == numMemsetNodes);
215+
216+ // Verify root nodes returned are memset nodes.
217+ HIP_CHECK (hipGraphNodeGetType (nodes[0 ], &nodeType));
218+ REQUIRE (nodeType == hipGraphNodeTypeMemset);
219+ HIP_CHECK (hipGraphNodeGetType (nodes[1 ], &nodeType));
220+ REQUIRE (nodeType == hipGraphNodeTypeMemset);
221+
222+ // Instantiate and launch the graph
223+ HIP_CHECK (hipGraphInstantiate (&graphExec, graph, NULL , NULL , 0 ));
224+ HIP_CHECK (hipGraphLaunch (graphExec, streamForGraph));
225+ HIP_CHECK (hipStreamSynchronize (streamForGraph));
226+
227+ // Validate the computation
228+ for (size_t i = 0 ; i < N; i++) {
229+ if (C_h[i] != A_h[i] * A_h[i]) {
230+ INFO (" A and C not matching at " << i << " C_h[i] " << C_h[i]
231+ << " A_h[i] " << A_h[i]);
232+ REQUIRE (false );
233+ }
234+ }
235+
236+ HIP_CHECK (hipGraphExecDestroy (graphExec));
237+ HIP_CHECK (hipGraphDestroy (graph));
238+ HIP_CHECK (hipStreamDestroy (streamForGraph));
239+ HIP_CHECK (hipStreamDestroy (mstream));
240+ HIP_CHECK (hipStreamDestroy (stream1));
241+ HIP_CHECK (hipStreamDestroy (stream2));
242+ HIP_CHECK (hipEventDestroy (forkStreamEvent));
243+ HIP_CHECK (hipEventDestroy (memsetEvent1));
244+ HIP_CHECK (hipEventDestroy (memsetEvent2));
245+ free (A_h);
246+ free (C_h);
247+ free (nodes);
248+ HIP_CHECK (hipFree (A_d));
249+ HIP_CHECK (hipFree (C_d));
250+ }
251+
252+
144253/* *
145254 * Test performs api parameter validation by passing various values
146255 * as input and output parameters and validates the behavior.
0 commit comments