@@ -189,7 +189,7 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
189189
190190 auto useEvents = std::get<1 >(GetParam ()).value ;
191191
192- std::vector<uur::raii::Event> Events (numOps * 2 );
192+ std::vector<uur::raii::Event> Events (numOps * 2 - 1 );
193193 for (size_t i = 0 ; i < numOps; i++) {
194194 size_t waitNum = 0 ;
195195 ur_event_handle_t *lastEvent = nullptr ;
@@ -202,7 +202,7 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
202202 lastEvent = i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
203203
204204 kernelEvent = Events[i * 2 ].ptr ();
205- memcpyEvent = Events[i * 2 + 1 ].ptr ();
205+ memcpyEvent = i < numOps - 1 ? Events[i * 2 + 1 ].ptr () : nullptr ;
206206 }
207207
208208 // execute kernel that increments each element by 1
@@ -220,9 +220,7 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
220220 }
221221
222222 if (useEvents) {
223- // TODO: just wait on the last event, once urEventWait is implemented
224- // by V2 L0 adapter
225- urQueueFinish (queue);
223+ urEventWait (1 , Events.back ().ptr ());
226224 } else {
227225 urQueueFinish (queue);
228226 }
@@ -237,12 +235,26 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
237235 }
238236}
239237
240- struct VoidParam {};
238+ template <typename T>
239+ inline std::string
240+ printBoolParam (const testing::TestParamInfo<typename T::ParamType> &info) {
241+ std::stringstream ss;
242+ ss << (info.param .value ? " " : " No" ) << info.param .name ;
243+ return ss.str ();
244+ }
245+
241246using urEnqueueKernelLaunchIncrementMultiDeviceTest =
242- urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<VoidParam>;
247+ urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<uur::BoolTestParam>;
248+
249+ INSTANTIATE_TEST_SUITE_P (
250+ , urEnqueueKernelLaunchIncrementMultiDeviceTest,
251+ testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UseEventWait" )),
252+ printBoolParam<urEnqueueKernelLaunchIncrementMultiDeviceTest>);
243253
244254// Do a chain of kernelLaunch(dev0) -> memcpy(dev0, dev1) -> kernelLaunch(dev1) ... ops
245- TEST_F (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
255+ TEST_P (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
256+ auto waitOnEvent = GetParam ().value ;
257+
246258 size_t returned_size;
247259 ASSERT_SUCCESS (urDeviceGetInfo (devices[0 ], UR_DEVICE_INFO_EXTENSIONS, 0 ,
248260 nullptr , &returned_size));
@@ -265,14 +277,15 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
265277 constexpr size_t global_offset = 0 ;
266278 constexpr size_t n_dimensions = 1 ;
267279
268- std::vector<uur::raii::Event> Events (devices.size () * 2 );
280+ std::vector<uur::raii::Event> Events (devices.size () * 2 - 1 );
269281 for (size_t i = 0 ; i < devices.size (); i++) {
270282 // Events are: kernelEvent0, memcpyEvent0, kernelEvent1, ...
271283 size_t waitNum = i > 0 ? 1 : 0 ;
272284 ur_event_handle_t *lastEvent =
273285 i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
274286 ur_event_handle_t *kernelEvent = Events[i * 2 ].ptr ();
275- ur_event_handle_t *memcpyEvent = Events[i * 2 + 1 ].ptr ();
287+ ur_event_handle_t *memcpyEvent =
288+ i < devices.size () - 1 ? Events[i * 2 + 1 ].ptr () : nullptr ;
276289
277290 // execute kernel that increments each element by 1
278291 ASSERT_SUCCESS (urEnqueueKernelLaunch (
@@ -287,9 +300,13 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
287300 }
288301 }
289302
290- // synchronize on the last queue only, this has to ensure all the operations
303+ // synchronize on the last queue/event only, this has to ensure all the operations
291304 // are completed
292- urQueueFinish (queues.back ());
305+ if (waitOnEvent) {
306+ urEventWait (1 , Events.back ().ptr ());
307+ } else {
308+ urQueueFinish (queues.back ());
309+ }
293310
294311 size_t ExpectedValue = InitialValue;
295312 for (size_t i = 0 ; i < devices.size (); i++) {
@@ -374,9 +391,11 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
374391 ArraySize * sizeof (uint32_t ), useEvents,
375392 lastEvent, signalEvent));
376393
377- urQueueFinish (queue);
378- // TODO: when useEvents is implemented for L0 v2 adapter
379- // wait on event instead
394+ if (useEvents) {
395+ urEventWait (1 , Events.back ().ptr ());
396+ } else {
397+ urQueueFinish (queue);
398+ }
380399
381400 size_t ExpectedValue = InitialValue;
382401 ExpectedValue += numOpsPerThread;
0 commit comments