1616#include < gtest/gtest.h>
1717
1818#include < sycl/sycl.hpp>
19+ #include < sycl/ext/oneapi/experimental/graph.hpp>
1920
2021using ::testing::HasSubstr;
2122using namespace sycl ;
2223XPTI_CALLBACK_API bool queryReceivedNotifications (uint16_t &TraceType,
2324 std::string &Message);
2425XPTI_CALLBACK_API void resetReceivedNotifications ();
2526XPTI_CALLBACK_API void addAnalyzedTraceType (uint16_t );
27+ XPTI_CALLBACK_API void clearAnalyzedTraceTypes ();
2628
2729class NodeCreation : public ::testing::Test {
2830protected:
@@ -34,6 +36,7 @@ class NodeCreation : public ::testing::Test {
3436
3537 void TearDown () {
3638 resetReceivedNotifications ();
39+ clearAnalyzedTraceTypes ();
3740 xptiForceSetTraceEnabled (false );
3841 }
3942
@@ -141,3 +144,95 @@ TEST_F(NodeCreation, QueueMemsetNode) {
141144 EXPECT_EQ (TraceType, xpti::trace_node_create);
142145 EXPECT_THAT (Message, HasSubstr (" memory_transfer_node" ));
143146}
147+
148+ TEST_F (NodeCreation, CommandGraphRecord) {
149+ sycl::queue Q;
150+ try {
151+ sycl::ext::oneapi::experimental::command_graph cmdGraph (Q.get_context (), Q.get_device ());
152+
153+ cmdGraph.begin_recording (Q);
154+
155+ {
156+ sycl::detail::tls_code_loc_t myLoc ({" LOCAL_CODELOC_FILE" , " LOCAL_CODELOC_NAME" , 1 , 1 });
157+ Q.submit (
158+ [&](handler &Cgh) {
159+ Cgh.parallel_for <TestKernel<KernelSize>>(1 , [=](sycl::id<1 > idx) {});
160+ });
161+ }
162+
163+ cmdGraph.end_recording (Q);
164+
165+ addAnalyzedTraceType (xpti::trace_task_begin);
166+ addAnalyzedTraceType (xpti::trace_task_end);
167+
168+ auto exeGraph = cmdGraph.finalize ();
169+
170+ // Notifications should have get generated during finalize
171+ //
172+ uint16_t TraceType = 0 ;
173+ std::string Message;
174+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
175+ EXPECT_EQ (TraceType, xpti::trace_node_create);
176+ EXPECT_THAT (Message, HasSubstr (" LOCAL_CODELOC_NAME" ));
177+
178+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
179+ EXPECT_EQ (TraceType, xpti::trace_task_begin);
180+
181+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
182+ EXPECT_EQ (TraceType, xpti::trace_task_end);
183+
184+ } catch (sycl::exception &e) {
185+ FAIL () << " sycl::exception what=" << e.what ();
186+ }
187+ }
188+
189+ TEST_F (NodeCreation, CommandGraphAddAPI) {
190+ sycl::queue Q;
191+ try {
192+ sycl::ext::oneapi::experimental::command_graph cmdGraph (Q.get_context (), Q.get_device ());
193+
194+ auto doAddNode = [&](const sycl::detail::code_location &loc) {
195+ sycl::detail::tls_code_loc_t codeLoc (loc);
196+ return cmdGraph.add (
197+ [&](handler &Cgh) {
198+ Cgh.parallel_for <TestKernel<KernelSize>>(1 , [=](sycl::id<1 > idx) {});
199+ });
200+ };
201+
202+ auto node1 = doAddNode ({" LOCAL_CODELOC_FILE" , " LOCAL_NODE_1" , 1 , 1 });
203+ auto node2 = doAddNode ({" LOCAL_CODELOC_FILE" , " LOCAL_NODE_2" , 2 , 1 });
204+ cmdGraph.make_edge (node1, node2);
205+
206+ addAnalyzedTraceType (xpti::trace_task_begin);
207+ addAnalyzedTraceType (xpti::trace_task_end);
208+
209+ auto exeGraph = cmdGraph.finalize ();
210+
211+ // Notifications should have get generated during finalize
212+ //
213+ uint16_t TraceType = 0 ;
214+ std::string Message;
215+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
216+ EXPECT_EQ (TraceType, xpti::trace_node_create);
217+ EXPECT_THAT (Message, HasSubstr (" LOCAL_NODE_1" ));
218+
219+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
220+ EXPECT_EQ (TraceType, xpti::trace_task_begin);
221+
222+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
223+ EXPECT_EQ (TraceType, xpti::trace_task_end);
224+
225+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
226+ EXPECT_EQ (TraceType, xpti::trace_node_create);
227+ EXPECT_THAT (Message, HasSubstr (" LOCAL_NODE_2" ));
228+
229+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
230+ EXPECT_EQ (TraceType, xpti::trace_task_begin);
231+
232+ ASSERT_TRUE (queryReceivedNotifications (TraceType, Message));
233+ EXPECT_EQ (TraceType, xpti::trace_task_end);
234+
235+ } catch (sycl::exception &e) {
236+ FAIL () << " sycl::exception what=" << e.what ();
237+ }
238+ }
0 commit comments