33// See LICENSE.TXT
44// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55
6+ #include " helpers.h"
7+
68#include < uur/fixtures.h>
79#include < uur/raii.h>
810
1315std::tuple<size_t , size_t , size_t > minL0DriverVersion = {1 , 3 , 29534 };
1416
1517template <typename T>
16- struct urMultiQueueLaunchMemcpyTest : uur::urMultiDeviceContextTestTemplate< 1 > ,
18+ struct urMultiQueueLaunchMemcpyTest : uur::urMultiQueueMultiDeviceTest ,
1719 testing::WithParamInterface<T> {
1820 std::string KernelName;
1921 std::vector<ur_program_handle_t > programs;
2022 std::vector<ur_kernel_handle_t > kernels;
2123 std::vector<void *> SharedMem;
2224
23- std::vector<ur_queue_handle_t > queues;
24- std::vector<ur_device_handle_t > devices;
25-
26- std::function<void (void )> createQueues;
27-
2825 static constexpr char ProgramName[] = " increment" ;
2926 static constexpr size_t ArraySize = 100 ;
3027 static constexpr size_t InitialValue = 1 ;
3128
32- void SetUp () override {
33- UUR_RETURN_ON_FATAL_FAILURE (
34- uur::urMultiDeviceContextTestTemplate<1 >::SetUp ());
29+ void SetUp () override { throw std::runtime_error (" Not implemented" ); }
3530
36- createQueues ();
31+ void SetUp (std::vector<ur_device_handle_t > srcDevices,
32+ size_t duplicateDevices) {
33+ UUR_RETURN_ON_FATAL_FAILURE (uur::urMultiQueueMultiDeviceTest::SetUp (
34+ srcDevices, duplicateDevices));
3735
3836 for (auto &device : devices) {
3937 SKIP_IF_DRIVER_TOO_OLD (" Level-Zero" , minL0DriverVersion, platform,
@@ -87,9 +85,6 @@ struct urMultiQueueLaunchMemcpyTest : uur::urMultiDeviceContextTestTemplate<1>,
8785 for (auto &Ptr : SharedMem) {
8886 urUSMFree (context, Ptr);
8987 }
90- for (const auto &queue : queues) {
91- EXPECT_SUCCESS (urQueueRelease (queue));
92- }
9388 for (const auto &kernel : kernels) {
9489 urKernelRelease (kernel);
9590 }
@@ -136,23 +131,8 @@ struct urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam
136131 using urMultiQueueLaunchMemcpyTest<Param>::SharedMem;
137132
138133 void SetUp () override {
139- this ->createQueues = [&] {
140- for (size_t i = 0 ; i < duplicateDevices; i++) {
141- devices.insert (
142- devices.end (),
143- uur::KernelsEnvironment::instance->devices .begin (),
144- uur::KernelsEnvironment::instance->devices .end ());
145- }
146-
147- for (auto &device : devices) {
148- ur_queue_handle_t queue = nullptr ;
149- ASSERT_SUCCESS (urQueueCreate (context, device, 0 , &queue));
150- queues.push_back (queue);
151- }
152- };
153-
154- UUR_RETURN_ON_FATAL_FAILURE (
155- urMultiQueueLaunchMemcpyTest<Param>::SetUp ());
134+ UUR_RETURN_ON_FATAL_FAILURE (urMultiQueueLaunchMemcpyTest<Param>::SetUp (
135+ uur::KernelsEnvironment::instance->devices , duplicateDevices));
156136 }
157137
158138 void TearDown () override {
@@ -166,8 +146,6 @@ struct urEnqueueKernelLaunchIncrementTest
166146 std::tuple<ur_device_handle_t , uur::BoolTestParam>> {
167147 static constexpr size_t numOps = 50 ;
168148
169- ur_queue_handle_t queue;
170-
171149 using Param = std::tuple<ur_device_handle_t , uur::BoolTestParam>;
172150 using urMultiQueueLaunchMemcpyTest<Param>::context;
173151 using urMultiQueueLaunchMemcpyTest<Param>::queues;
@@ -176,26 +154,12 @@ struct urEnqueueKernelLaunchIncrementTest
176154 using urMultiQueueLaunchMemcpyTest<Param>::SharedMem;
177155
178156 void SetUp () override {
179- auto device = std::get<0 >(GetParam ());
180-
181- this ->createQueues = [&] {
182- ASSERT_SUCCESS (urQueueCreate (context, device, 0 , &queue));
183-
184- // use the same queue and device for all operations
185- for (size_t i = 0 ; i < numOps; i++) {
186- urQueueRetain (queue);
187-
188- queues.push_back (queue);
189- devices.push_back (device);
190- }
191- };
192-
193- UUR_RETURN_ON_FATAL_FAILURE (
194- urMultiQueueLaunchMemcpyTest<Param>::SetUp ());
157+ UUR_RETURN_ON_FATAL_FAILURE (urMultiQueueLaunchMemcpyTest<Param>::SetUp (
158+ std::vector<ur_device_handle_t >{std::get<0 >(GetParam ())},
159+ numOps)); // Use single device, duplicated numOps times
195160 }
196161
197162 void TearDown () override {
198- urQueueRelease (queue);
199163 UUR_RETURN_ON_FATAL_FAILURE (
200164 urMultiQueueLaunchMemcpyTest<Param>::TearDown ());
201165 }
@@ -219,6 +183,9 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
219183 ur_event_handle_t *kernelEvent = nullptr ;
220184 ur_event_handle_t *memcpyEvent = nullptr ;
221185
186+ // This is a single device test
187+ auto queue = queues[0 ];
188+
222189 for (size_t i = 0 ; i < numOps; i++) {
223190 if (useEvents) {
224191 lastMemcpyEvent = memcpyEvent;
0 commit comments