@@ -189,7 +189,7 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
189
189
190
190
auto useEvents = std::get<1 >(GetParam ()).value ;
191
191
192
- std::vector<uur::raii::Event> Events (numOps * 2 );
192
+ std::vector<uur::raii::Event> Events (numOps * 2 - 1 );
193
193
for (size_t i = 0 ; i < numOps; i++) {
194
194
size_t waitNum = 0 ;
195
195
ur_event_handle_t *lastEvent = nullptr ;
@@ -202,7 +202,7 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
202
202
lastEvent = i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
203
203
204
204
kernelEvent = Events[i * 2 ].ptr ();
205
- memcpyEvent = Events[i * 2 + 1 ].ptr ();
205
+ memcpyEvent = i < numOps - 1 ? Events[i * 2 + 1 ].ptr () : nullptr ;
206
206
}
207
207
208
208
// execute kernel that increments each element by 1
@@ -220,9 +220,7 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
220
220
}
221
221
222
222
if (useEvents) {
223
- // TODO: just wait on the last event, once urEventWait is implemented
224
- // by V2 L0 adapter
225
- urQueueFinish (queue);
223
+ urEventWait (1 , Events.back ().ptr ());
226
224
} else {
227
225
urQueueFinish (queue);
228
226
}
@@ -237,12 +235,26 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
237
235
}
238
236
}
239
237
240
- struct VoidParam {};
238
+ template <typename T>
239
+ inline std::string
240
+ printBoolParam (const testing::TestParamInfo<typename T::ParamType> &info) {
241
+ std::stringstream ss;
242
+ ss << (info.param .value ? " " : " No" ) << info.param .name ;
243
+ return ss.str ();
244
+ }
245
+
241
246
using urEnqueueKernelLaunchIncrementMultiDeviceTest =
242
- urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<VoidParam>;
247
+ urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<uur::BoolTestParam>;
248
+
249
+ INSTANTIATE_TEST_SUITE_P (
250
+ , urEnqueueKernelLaunchIncrementMultiDeviceTest,
251
+ testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UseEventWait" )),
252
+ printBoolParam<urEnqueueKernelLaunchIncrementMultiDeviceTest>);
243
253
244
254
// Do a chain of kernelLaunch(dev0) -> memcpy(dev0, dev1) -> kernelLaunch(dev1) ... ops
245
- TEST_F (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
255
+ TEST_P (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
256
+ auto waitOnEvent = GetParam ().value ;
257
+
246
258
size_t returned_size;
247
259
ASSERT_SUCCESS (urDeviceGetInfo (devices[0 ], UR_DEVICE_INFO_EXTENSIONS, 0 ,
248
260
nullptr , &returned_size));
@@ -265,14 +277,15 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
265
277
constexpr size_t global_offset = 0 ;
266
278
constexpr size_t n_dimensions = 1 ;
267
279
268
- std::vector<uur::raii::Event> Events (devices.size () * 2 );
280
+ std::vector<uur::raii::Event> Events (devices.size () * 2 - 1 );
269
281
for (size_t i = 0 ; i < devices.size (); i++) {
270
282
// Events are: kernelEvent0, memcpyEvent0, kernelEvent1, ...
271
283
size_t waitNum = i > 0 ? 1 : 0 ;
272
284
ur_event_handle_t *lastEvent =
273
285
i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
274
286
ur_event_handle_t *kernelEvent = Events[i * 2 ].ptr ();
275
- ur_event_handle_t *memcpyEvent = Events[i * 2 + 1 ].ptr ();
287
+ ur_event_handle_t *memcpyEvent =
288
+ i < devices.size () - 1 ? Events[i * 2 + 1 ].ptr () : nullptr ;
276
289
277
290
// execute kernel that increments each element by 1
278
291
ASSERT_SUCCESS (urEnqueueKernelLaunch (
@@ -287,9 +300,13 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
287
300
}
288
301
}
289
302
290
- // synchronize on the last queue only, this has to ensure all the operations
303
+ // synchronize on the last queue/event only, this has to ensure all the operations
291
304
// are completed
292
- urQueueFinish (queues.back ());
305
+ if (waitOnEvent) {
306
+ urEventWait (1 , Events.back ().ptr ());
307
+ } else {
308
+ urQueueFinish (queues.back ());
309
+ }
293
310
294
311
size_t ExpectedValue = InitialValue;
295
312
for (size_t i = 0 ; i < devices.size (); i++) {
@@ -374,9 +391,11 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
374
391
ArraySize * sizeof (uint32_t ), useEvents,
375
392
lastEvent, signalEvent));
376
393
377
- urQueueFinish (queue);
378
- // TODO: when useEvents is implemented for L0 v2 adapter
379
- // wait on event instead
394
+ if (useEvents) {
395
+ urEventWait (1 , Events.back ().ptr ());
396
+ } else {
397
+ urQueueFinish (queue);
398
+ }
380
399
381
400
size_t ExpectedValue = InitialValue;
382
401
ExpectedValue += numOpsPerThread;
0 commit comments