Skip to content

Commit dd3ed98

Browse files
committed
Added two unit tests to check notification generation for command graph API.
Signed-off-by: Guy Zadicario <[email protected]>
1 parent ab29f52 commit dd3ed98

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

sycl/unittests/xpti_trace/NodeCreation.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
#include <gtest/gtest.h>
1717

1818
#include <sycl/sycl.hpp>
19+
#include <sycl/ext/oneapi/experimental/graph.hpp>
1920

2021
using ::testing::HasSubstr;
2122
using namespace sycl;
2223
XPTI_CALLBACK_API bool queryReceivedNotifications(uint16_t &TraceType,
2324
std::string &Message);
2425
XPTI_CALLBACK_API void resetReceivedNotifications();
2526
XPTI_CALLBACK_API void addAnalyzedTraceType(uint16_t);
27+
XPTI_CALLBACK_API void clearAnalyzedTraceTypes();
2628

2729
class NodeCreation : public ::testing::Test {
2830
protected:
@@ -34,6 +36,7 @@ class NodeCreation : public ::testing::Test {
3436

3537
void TearDown() {
3638
resetReceivedNotifications();
39+
clearAnalyzedTraceTypes();
3740
xptiForceSetTraceEnabled(false);
3841
}
3942

@@ -141,3 +144,95 @@ TEST_F(NodeCreation, QueueMemsetNode) {
141144
EXPECT_EQ(TraceType, xpti::trace_node_create);
142145
EXPECT_THAT(Message, HasSubstr("memory_transfer_node"));
143146
}
147+
148+
TEST_F(NodeCreation, CommandGraphRecord) {
149+
sycl::queue Q;
150+
try {
151+
sycl::ext::oneapi::experimental::command_graph cmdGraph(Q.get_context(), Q.get_device());
152+
153+
cmdGraph.begin_recording(Q);
154+
155+
{
156+
sycl::detail::tls_code_loc_t myLoc({"LOCAL_CODELOC_FILE", "LOCAL_CODELOC_NAME", 1, 1});
157+
Q.submit(
158+
[&](handler &Cgh) {
159+
Cgh.parallel_for<TestKernel<KernelSize>>(1, [=](sycl::id<1> idx) {});
160+
});
161+
}
162+
163+
cmdGraph.end_recording(Q);
164+
165+
addAnalyzedTraceType(xpti::trace_task_begin);
166+
addAnalyzedTraceType(xpti::trace_task_end);
167+
168+
auto exeGraph = cmdGraph.finalize();
169+
170+
// Notifications should have get generated during finalize
171+
//
172+
uint16_t TraceType = 0;
173+
std::string Message;
174+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
175+
EXPECT_EQ(TraceType, xpti::trace_node_create);
176+
EXPECT_THAT(Message, HasSubstr("LOCAL_CODELOC_NAME"));
177+
178+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
179+
EXPECT_EQ(TraceType, xpti::trace_task_begin);
180+
181+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
182+
EXPECT_EQ(TraceType, xpti::trace_task_end);
183+
184+
} catch (sycl::exception &e) {
185+
FAIL() << "sycl::exception what=" << e.what();
186+
}
187+
}
188+
189+
TEST_F(NodeCreation, CommandGraphAddAPI) {
190+
sycl::queue Q;
191+
try {
192+
sycl::ext::oneapi::experimental::command_graph cmdGraph(Q.get_context(), Q.get_device());
193+
194+
auto doAddNode = [&](const sycl::detail::code_location &loc) {
195+
sycl::detail::tls_code_loc_t codeLoc(loc);
196+
return cmdGraph.add(
197+
[&](handler &Cgh) {
198+
Cgh.parallel_for<TestKernel<KernelSize>>(1, [=](sycl::id<1> idx) {});
199+
});
200+
};
201+
202+
auto node1 = doAddNode({"LOCAL_CODELOC_FILE", "LOCAL_NODE_1", 1, 1});
203+
auto node2 = doAddNode({"LOCAL_CODELOC_FILE", "LOCAL_NODE_2", 2, 1});
204+
cmdGraph.make_edge(node1, node2);
205+
206+
addAnalyzedTraceType(xpti::trace_task_begin);
207+
addAnalyzedTraceType(xpti::trace_task_end);
208+
209+
auto exeGraph = cmdGraph.finalize();
210+
211+
// Notifications should have get generated during finalize
212+
//
213+
uint16_t TraceType = 0;
214+
std::string Message;
215+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
216+
EXPECT_EQ(TraceType, xpti::trace_node_create);
217+
EXPECT_THAT(Message, HasSubstr("LOCAL_NODE_1"));
218+
219+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
220+
EXPECT_EQ(TraceType, xpti::trace_task_begin);
221+
222+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
223+
EXPECT_EQ(TraceType, xpti::trace_task_end);
224+
225+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
226+
EXPECT_EQ(TraceType, xpti::trace_node_create);
227+
EXPECT_THAT(Message, HasSubstr("LOCAL_NODE_2"));
228+
229+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
230+
EXPECT_EQ(TraceType, xpti::trace_task_begin);
231+
232+
ASSERT_TRUE(queryReceivedNotifications(TraceType, Message));
233+
EXPECT_EQ(TraceType, xpti::trace_task_end);
234+
235+
} catch (sycl::exception &e) {
236+
FAIL() << "sycl::exception what=" << e.what();
237+
}
238+
}

0 commit comments

Comments
 (0)