Skip to content

Commit 60a07fd

Browse files
authored
SWDEV-306122 - [catch2][dtest] Additional tests for hipGraphGetNodes/hipGraphGetRootNodes api (#2860)
Change-Id: I4e9a6eadd92b6de4aded0aa96464bfcc59441ac4
1 parent c108e59 commit 60a07fd

File tree

2 files changed

+215
-0
lines changed

2 files changed

+215
-0
lines changed

tests/catch/unit/graph/hipGraphGetNodes.cc

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Functional ::
2525
2) Pass nodes as nullptr and verify numNodes returns actual number of nodes added to graph.
2626
3) If numNodes passed is greater than the actual number of nodes, the remaining entries in nodes
2727
will be set to NULL, and the number of nodes actually obtained will be returned in numNodes.
28+
4) Begin stream capture and push operations to stream. Verify nodes of created graph are matching the
29+
operations pushed.
2830
2931
Argument Validation ::
3032
1) Pass graph as nullptr and verify api returns error code.
@@ -139,6 +141,110 @@ TEST_CASE("Unit_hipGraphGetNodes_Functional") {
139141
free(nodes);
140142
}
141143

144+
/**
145+
* Begin stream capture and push operations to stream.
146+
* Verify nodes of created graph are matching the operations pushed.
147+
*/
148+
TEST_CASE("Unit_hipGraphGetNodes_CapturedStream") {
149+
hipGraph_t graph{nullptr};
150+
hipGraphExec_t graphExec{nullptr};
151+
constexpr unsigned blocks = 512;
152+
constexpr unsigned threadsPerBlock = 256;
153+
constexpr size_t N = 1000000;
154+
size_t Nbytes = N * sizeof(float);
155+
constexpr int numMemcpy{2}, numKernel{1}, numMemset{1};
156+
int cntMemcpy{}, cntKernel{}, cntMemset{};
157+
hipStream_t stream, streamForGraph;
158+
hipGraphNodeType nodeType;
159+
float *A_d, *C_d;
160+
float *A_h, *C_h;
161+
162+
A_h = reinterpret_cast<float*>(malloc(Nbytes));
163+
C_h = reinterpret_cast<float*>(malloc(Nbytes));
164+
REQUIRE(A_h != nullptr);
165+
REQUIRE(C_h != nullptr);
166+
HIP_CHECK(hipMalloc(&A_d, Nbytes));
167+
HIP_CHECK(hipMalloc(&C_d, Nbytes));
168+
REQUIRE(A_d != nullptr);
169+
REQUIRE(C_d != nullptr);
170+
171+
HIP_CHECK(hipStreamCreate(&streamForGraph));
172+
// Initialize input buffer
173+
for (size_t i = 0; i < N; ++i) {
174+
A_h[i] = 3.146f + i; // Pi
175+
}
176+
177+
HIP_CHECK(hipStreamCreate(&stream));
178+
HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
179+
HIP_CHECK(hipMemcpyAsync(A_d, A_h, Nbytes, hipMemcpyHostToDevice, stream));
180+
HIP_CHECK(hipMemsetAsync(C_d, 0, Nbytes, stream));
181+
hipLaunchKernelGGL(HipTest::vector_square, dim3(blocks),
182+
dim3(threadsPerBlock), 0, stream, A_d, C_d, N);
183+
HIP_CHECK(hipMemcpyAsync(C_h, C_d, Nbytes, hipMemcpyDeviceToHost, stream));
184+
HIP_CHECK(hipStreamEndCapture(stream, &graph));
185+
REQUIRE(graph != nullptr);
186+
187+
size_t numNodes{};
188+
HIP_CHECK(hipGraphGetNodes(graph, nullptr, &numNodes));
189+
INFO("Num of nodes returned by GetNodes : " << numNodes);
190+
REQUIRE(numNodes == numMemcpy + numKernel + numMemset);
191+
192+
int numBytes = sizeof(hipGraphNode_t) * numNodes;
193+
hipGraphNode_t* nodes = reinterpret_cast<hipGraphNode_t *>(malloc(numBytes));
194+
REQUIRE(nodes != nullptr);
195+
196+
HIP_CHECK(hipGraphGetNodes(graph, nodes, &numNodes));
197+
for (size_t i = 0; i < numNodes; i++) {
198+
HIP_CHECK(hipGraphNodeGetType(nodes[i], &nodeType));
199+
200+
switch (nodeType) {
201+
case hipGraphNodeTypeMemcpy:
202+
cntMemcpy++;
203+
break;
204+
205+
case hipGraphNodeTypeKernel:
206+
cntKernel++;
207+
break;
208+
209+
case hipGraphNodeTypeMemset:
210+
cntMemset++;
211+
break;
212+
213+
default:
214+
INFO("Unexpected nodetype returned : " << nodeType);
215+
REQUIRE(false);
216+
}
217+
}
218+
219+
REQUIRE(cntMemcpy == numMemcpy);
220+
REQUIRE(cntKernel == numKernel);
221+
REQUIRE(cntMemset == numMemset);
222+
223+
// Instantiate and launch the graph
224+
HIP_CHECK(hipGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
225+
HIP_CHECK(hipGraphLaunch(graphExec, streamForGraph));
226+
HIP_CHECK(hipStreamSynchronize(streamForGraph));
227+
228+
// Validate the computation
229+
for (size_t i = 0; i < N; i++) {
230+
if (C_h[i] != A_h[i] * A_h[i]) {
231+
INFO("A and C not matching at " << i << " C_h[i] " << C_h[i]
232+
<< " A_h[i] " << A_h[i]);
233+
REQUIRE(false);
234+
}
235+
}
236+
237+
HIP_CHECK(hipStreamDestroy(streamForGraph));
238+
HIP_CHECK(hipStreamDestroy(stream));
239+
HIP_CHECK(hipGraphExecDestroy(graphExec));
240+
HIP_CHECK(hipGraphDestroy(graph));
241+
free(A_h);
242+
free(C_h);
243+
free(nodes);
244+
HIP_CHECK(hipFree(A_d));
245+
HIP_CHECK(hipFree(C_d));
246+
}
247+
142248
/**
143249
* Test performs api parameter validation by passing various values
144250
* as input and output parameters and validates the behavior.

tests/catch/unit/graph/hipGraphGetRootNodes.cc

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Functional ::
2626
2) Pass nodes as nullptr and verify api returns actual number of root nodes added to graph.
2727
3) If NumRootNodes passed is greater than the actual number of root nodes, the remaining entries in
2828
nodes list will be set to NULL, and the number of nodes actually obtained will be returned in NumRootNodes.
29+
4) Create a graph with stream capture done on multiple dependent streams.
30+
Verify root nodes of created graph are matching the operations pushed which doesn't have dependencies.
2931
3032
Argument Validation ::
3133
1) Pass graph as nullptr and verify api returns error code.
@@ -50,6 +52,7 @@ TEST_CASE("Unit_hipGraphGetRootNodes_Functional") {
5052
constexpr auto addlEntries = 5;
5153
hipGraph_t graph;
5254

55+
5356
hipGraphNode_t memcpyNode, kernelNode;
5457
hipKernelNodeParams kernelNodeParams{};
5558
hipStream_t streamForGraph;
@@ -141,6 +144,112 @@ TEST_CASE("Unit_hipGraphGetRootNodes_Functional") {
141144
free(rootnodes);
142145
}
143146

147+
148+
/**
149+
* Create a graph with stream capture done on multiple dependent streams. Verify root nodes
150+
* of created graph are matching the operations pushed which doesn't have dependencies.
151+
*/
152+
TEST_CASE("Unit_hipGraphGetRootNodes_CapturedStream") {
153+
hipStream_t stream1{nullptr}, stream2{nullptr}, mstream{nullptr};
154+
hipStream_t streamForGraph{nullptr};
155+
hipEvent_t memsetEvent1, memsetEvent2, forkStreamEvent;
156+
hipGraph_t graph{nullptr};
157+
hipGraphExec_t graphExec{nullptr};
158+
constexpr unsigned blocks = 512;
159+
constexpr unsigned threadsPerBlock = 256;
160+
constexpr size_t N = 1000000;
161+
constexpr int numMemsetNodes = 2;
162+
size_t Nbytes = N * sizeof(float), numRootNodes{};
163+
float *A_d, *C_d;
164+
float *A_h, *C_h;
165+
A_h = reinterpret_cast<float*>(malloc(Nbytes));
166+
C_h = reinterpret_cast<float*>(malloc(Nbytes));
167+
REQUIRE(A_h != nullptr);
168+
REQUIRE(C_h != nullptr);
169+
HIP_CHECK(hipMalloc(&A_d, Nbytes));
170+
HIP_CHECK(hipMalloc(&C_d, Nbytes));
171+
REQUIRE(A_d != nullptr);
172+
REQUIRE(C_d != nullptr);
173+
174+
HIP_CHECK(hipStreamCreate(&streamForGraph));
175+
176+
// Initialize input buffer
177+
for (size_t i = 0; i < N; ++i) {
178+
A_h[i] = 3.146f + i; // Pi
179+
}
180+
181+
HIP_CHECK(hipStreamCreate(&stream1));
182+
HIP_CHECK(hipStreamCreate(&stream2));
183+
HIP_CHECK(hipStreamCreate(&mstream));
184+
HIP_CHECK(hipEventCreate(&memsetEvent1));
185+
HIP_CHECK(hipEventCreate(&memsetEvent2));
186+
HIP_CHECK(hipEventCreate(&forkStreamEvent));
187+
HIP_CHECK(hipStreamBeginCapture(mstream, hipStreamCaptureModeGlobal));
188+
HIP_CHECK(hipEventRecord(forkStreamEvent, mstream));
189+
HIP_CHECK(hipStreamWaitEvent(stream1, forkStreamEvent, 0));
190+
HIP_CHECK(hipStreamWaitEvent(stream2, forkStreamEvent, 0));
191+
HIP_CHECK(hipMemsetAsync(A_d, 0, Nbytes, stream1));
192+
HIP_CHECK(hipEventRecord(memsetEvent1, stream1));
193+
HIP_CHECK(hipMemsetAsync(C_d, 0, Nbytes, stream2));
194+
HIP_CHECK(hipEventRecord(memsetEvent2, stream2));
195+
HIP_CHECK(hipStreamWaitEvent(mstream, memsetEvent1, 0));
196+
HIP_CHECK(hipStreamWaitEvent(mstream, memsetEvent2, 0));
197+
HIP_CHECK(hipMemcpyAsync(A_d, A_h, Nbytes, hipMemcpyHostToDevice, mstream));
198+
hipLaunchKernelGGL(HipTest::vector_square, dim3(blocks),
199+
dim3(threadsPerBlock), 0, mstream, A_d, C_d, N);
200+
HIP_CHECK(hipMemcpyAsync(C_h, C_d, Nbytes, hipMemcpyDeviceToHost, mstream));
201+
HIP_CHECK(hipStreamEndCapture(mstream, &graph));
202+
203+
// Verify numof root nodes
204+
HIP_CHECK(hipGraphGetRootNodes(graph, nullptr, &numRootNodes));
205+
REQUIRE(numRootNodes == numMemsetNodes);
206+
INFO("Num of nodes returned by GetRootNodes : " << numRootNodes);
207+
208+
int numBytes = sizeof(hipGraphNode_t) * numRootNodes;
209+
hipGraphNode_t* nodes = reinterpret_cast<hipGraphNode_t *>(malloc(numBytes));
210+
REQUIRE(nodes != nullptr);
211+
212+
hipGraphNodeType nodeType;
213+
HIP_CHECK(hipGraphGetRootNodes(graph, nodes, &numRootNodes));
214+
REQUIRE(numRootNodes == numMemsetNodes);
215+
216+
// Verify root nodes returned are memset nodes.
217+
HIP_CHECK(hipGraphNodeGetType(nodes[0], &nodeType));
218+
REQUIRE(nodeType == hipGraphNodeTypeMemset);
219+
HIP_CHECK(hipGraphNodeGetType(nodes[1], &nodeType));
220+
REQUIRE(nodeType == hipGraphNodeTypeMemset);
221+
222+
// Instantiate and launch the graph
223+
HIP_CHECK(hipGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
224+
HIP_CHECK(hipGraphLaunch(graphExec, streamForGraph));
225+
HIP_CHECK(hipStreamSynchronize(streamForGraph));
226+
227+
// Validate the computation
228+
for (size_t i = 0; i < N; i++) {
229+
if (C_h[i] != A_h[i] * A_h[i]) {
230+
INFO("A and C not matching at " << i << " C_h[i] " << C_h[i]
231+
<< " A_h[i] " << A_h[i]);
232+
REQUIRE(false);
233+
}
234+
}
235+
236+
HIP_CHECK(hipGraphExecDestroy(graphExec));
237+
HIP_CHECK(hipGraphDestroy(graph));
238+
HIP_CHECK(hipStreamDestroy(streamForGraph));
239+
HIP_CHECK(hipStreamDestroy(mstream));
240+
HIP_CHECK(hipStreamDestroy(stream1));
241+
HIP_CHECK(hipStreamDestroy(stream2));
242+
HIP_CHECK(hipEventDestroy(forkStreamEvent));
243+
HIP_CHECK(hipEventDestroy(memsetEvent1));
244+
HIP_CHECK(hipEventDestroy(memsetEvent2));
245+
free(A_h);
246+
free(C_h);
247+
free(nodes);
248+
HIP_CHECK(hipFree(A_d));
249+
HIP_CHECK(hipFree(C_d));
250+
}
251+
252+
144253
/**
145254
* Test performs api parameter validation by passing various values
146255
* as input and output parameters and validates the behavior.

0 commit comments

Comments
 (0)