@@ -99,6 +99,29 @@ struct urMultiQueueLaunchMemcpyTest : uur::urMultiDeviceContextTestTemplate<1>,
9999 UUR_RETURN_ON_FATAL_FAILURE (
100100 uur::urMultiDeviceContextTestTemplate<1 >::TearDown ());
101101 }
102+
103+ void runBackgroundCheck (std::vector<uur::raii::Event> &Events) {
104+ std::vector<std::thread> threads;
105+ for (size_t i = 0 ; i < Events.size (); i++) {
106+ threads.emplace_back ([&, i] {
107+ ur_event_status_t status;
108+ do {
109+ ASSERT_SUCCESS (urEventGetInfo (
110+ Events[i].get (), UR_EVENT_INFO_COMMAND_EXECUTION_STATUS,
111+ sizeof (ur_event_status_t ), &status, nullptr ));
112+ } while (status != UR_EVENT_STATUS_COMPLETE);
113+
114+ auto ExpectedValue = InitialValue + i + 1 ;
115+ for (uint32_t j = 0 ; j < ArraySize; ++j) {
116+ ASSERT_EQ (reinterpret_cast <uint32_t *>(SharedMem[i])[j],
117+ ExpectedValue);
118+ }
119+ });
120+ }
121+ for (auto &thread : threads) {
122+ thread.join ();
123+ }
124+ }
102125};
103126
104127template <typename Param>
@@ -189,26 +212,24 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
189212
190213 auto useEvents = std::get<1 >(GetParam ()).value ;
191214
192- std::vector<uur::raii::Event> Events (numOps * 2 );
193- for (size_t i = 0 ; i < numOps; i++) {
194- size_t waitNum = 0 ;
195- ur_event_handle_t *lastEvent = nullptr ;
196- ur_event_handle_t *kernelEvent = nullptr ;
197- ur_event_handle_t *memcpyEvent = nullptr ;
215+ std::vector<uur::raii::Event> kernelEvents (numOps);
216+ std::vector<uur::raii::Event> memcpyEvents (numOps - 1 );
198217
199- if (useEvents) {
200- // Events are: kernelEvent0, memcpyEvent0, kernelEvent1, ...
201- waitNum = i > 0 ? 1 : 0 ;
202- lastEvent = i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
218+ ur_event_handle_t *lastMemcpyEvent = nullptr ;
219+ ur_event_handle_t *kernelEvent = nullptr ;
220+ ur_event_handle_t *memcpyEvent = nullptr ;
203221
204- kernelEvent = Events[i * 2 ].ptr ();
205- memcpyEvent = Events[i * 2 + 1 ].ptr ();
222+ for (size_t i = 0 ; i < numOps; i++) {
223+ if (useEvents) {
224+ lastMemcpyEvent = memcpyEvent;
225+ kernelEvent = kernelEvents[i].ptr ();
226+ memcpyEvent = i < numOps - 1 ? memcpyEvents[i].ptr () : nullptr ;
206227 }
207228
208229 // execute kernel that increments each element by 1
209230 ASSERT_SUCCESS (urEnqueueKernelLaunch (
210231 queue, kernels[i], n_dimensions, &global_offset, &ArraySize,
211- nullptr , waitNum, lastEvent , kernelEvent));
232+ nullptr , bool (lastMemcpyEvent), lastMemcpyEvent , kernelEvent));
212233
213234 // copy the memory (input for the next kernel)
214235 if (i < numOps - 1 ) {
@@ -220,11 +241,9 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
220241 }
221242
222243 if (useEvents) {
223- // TODO: just wait on the last event, once urEventWait is implemented
224- // by V2 L0 adapter
225- urQueueFinish (queue);
244+ ASSERT_SUCCESS (urEventWait (1 , kernelEvents.back ().ptr ()));
226245 } else {
227- urQueueFinish (queue);
246+ ASSERT_SUCCESS ( urQueueFinish (queue) );
228247 }
229248
230249 size_t ExpectedValue = InitialValue;
@@ -237,12 +256,41 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
237256 }
238257}
239258
240- struct VoidParam {};
259+ template <typename T>
260+ inline std::string
261+ printParams (const testing::TestParamInfo<typename T::ParamType> &info) {
262+ std::stringstream ss;
263+
264+ auto param1 = std::get<0 >(info.param );
265+ ss << (param1.value ? " " : " No" ) << param1.name ;
266+
267+ auto param2 = std::get<1 >(info.param );
268+ ss << (param2.value ? " " : " No" ) << param2.name ;
269+
270+ if constexpr (std::tuple_size_v < typename T::ParamType >> 2 ) {
271+ auto param3 = std::get<2 >(info.param );
272+ }
273+
274+ return ss.str ();
275+ }
276+
241277using urEnqueueKernelLaunchIncrementMultiDeviceTest =
242- urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<VoidParam>;
278+ urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
279+ std::tuple<uur::BoolTestParam, uur::BoolTestParam>>;
280+
281+ INSTANTIATE_TEST_SUITE_P (
282+ , urEnqueueKernelLaunchIncrementMultiDeviceTest,
283+ testing::Combine (
284+ testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UseEventWait" )),
285+ testing::ValuesIn(
286+ uur::BoolTestParam::makeBoolParam (" RunBackgroundCheck" ))),
287+ printParams<urEnqueueKernelLaunchIncrementMultiDeviceTest>);
243288
244289// Do a chain of kernelLaunch(dev0) -> memcpy(dev0, dev1) -> kernelLaunch(dev1) ... ops
245- TEST_F (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
290+ TEST_P (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
291+ auto waitOnEvent = std::get<0 >(GetParam ()).value ;
292+ auto runBackgroundCheck = std::get<1 >(GetParam ()).value ;
293+
246294 size_t returned_size;
247295 ASSERT_SUCCESS (urDeviceGetInfo (devices[0 ], UR_DEVICE_INFO_EXTENSIONS, 0 ,
248296 nullptr , &returned_size));
@@ -265,19 +313,22 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
265313 constexpr size_t global_offset = 0 ;
266314 constexpr size_t n_dimensions = 1 ;
267315
268- std::vector<uur::raii::Event> Events (devices.size () * 2 );
316+ std::vector<uur::raii::Event> kernelEvents (devices.size ());
317+ std::vector<uur::raii::Event> memcpyEvents (devices.size () - 1 );
318+
319+ ur_event_handle_t *lastMemcpyEvent = nullptr ;
320+ ur_event_handle_t *kernelEvent = nullptr ;
321+ ur_event_handle_t *memcpyEvent = nullptr ;
322+
269323 for (size_t i = 0 ; i < devices.size (); i++) {
270- // Events are: kernelEvent0, memcpyEvent0, kernelEvent1, ...
271- size_t waitNum = i > 0 ? 1 : 0 ;
272- ur_event_handle_t *lastEvent =
273- i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
274- ur_event_handle_t *kernelEvent = Events[i * 2 ].ptr ();
275- ur_event_handle_t *memcpyEvent = Events[i * 2 + 1 ].ptr ();
324+ lastMemcpyEvent = memcpyEvent;
325+ kernelEvent = kernelEvents[i].ptr ();
326+ memcpyEvent = i < devices.size () - 1 ? memcpyEvents[i].ptr () : nullptr ;
276327
277328 // execute kernel that increments each element by 1
278329 ASSERT_SUCCESS (urEnqueueKernelLaunch (
279330 queues[i], kernels[i], n_dimensions, &global_offset, &ArraySize,
280- nullptr , waitNum, lastEvent , kernelEvent));
331+ nullptr , bool (lastMemcpyEvent), lastMemcpyEvent , kernelEvent));
281332
282333 // copy the memory to next device
283334 if (i < devices.size () - 1 ) {
@@ -287,9 +338,18 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
287338 }
288339 }
289340
290- // synchronize on the last queue only, this has to ensure all the operations
341+ // While the device(s) execute, loop over the events and if completed, verify the results
342+ if (runBackgroundCheck) {
343+ this ->runBackgroundCheck (kernelEvents);
344+ }
345+
346+ // synchronize on the last queue/event only, this has to ensure all the operations
291347 // are completed
292- urQueueFinish (queues.back ());
348+ if (waitOnEvent) {
349+ ASSERT_SUCCESS (urEventWait (1 , kernelEvents.back ().ptr ()));
350+ } else {
351+ ASSERT_SUCCESS (urQueueFinish (queues.back ()));
352+ }
293353
294354 size_t ExpectedValue = InitialValue;
295355 for (size_t i = 0 ; i < devices.size (); i++) {
@@ -301,20 +361,6 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
301361 }
302362}
303363
304- template <typename T>
305- inline std::string
306- printParams (const testing::TestParamInfo<typename T::ParamType> &info) {
307- std::stringstream ss;
308-
309- auto param1 = std::get<0 >(info.param );
310- auto param2 = std::get<1 >(info.param );
311-
312- ss << (param1.value ? " " : " No" ) << param1.name ;
313- ss << (param2.value ? " " : " No" ) << param2.name ;
314-
315- return ss.str ();
316- }
317-
318364using urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest =
319365 urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
320366 std::tuple<uur::BoolTestParam, uur::BoolTestParam>>;
@@ -374,9 +420,11 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
374420 ArraySize * sizeof (uint32_t ), useEvents,
375421 lastEvent, signalEvent));
376422
377- urQueueFinish (queue);
378- // TODO: when useEvents is implemented for L0 v2 adapter
379- // wait on event instead
423+ if (useEvents) {
424+ ASSERT_SUCCESS (urEventWait (1 , Events.back ().ptr ()));
425+ } else {
426+ ASSERT_SUCCESS (urQueueFinish (queue));
427+ }
380428
381429 size_t ExpectedValue = InitialValue;
382430 ExpectedValue += numOpsPerThread;
0 commit comments