66#include " helpers.h"
77#include < uur/fixtures.h>
88
9+ using TestParametersMemcpy2D =
10+ std::tuple<uur::TestParameters2D, uur::USMKind, uur::USMKind>;
11+
912struct urEnqueueUSMMemcpy2DTestWithParam
10- : uur::urQueueTestWithParam<uur::TestParameters2D > {
13+ : uur::urQueueTestWithParam<TestParametersMemcpy2D > {
1114 void SetUp () override {
1215 UUR_RETURN_ON_FATAL_FAILURE (
13- uur::urQueueTestWithParam<uur::TestParameters2D>::SetUp ());
16+ uur::urQueueTestWithParam<TestParametersMemcpy2D>::SetUp ());
17+
18+ const auto [in2DParams, inSrcKind, inDstKind] = getParam ();
19+ std::tie (pitch, width, height, src_kind, dst_kind) =
20+ std::make_tuple (in2DParams.pitch , in2DParams.width ,
21+ in2DParams.height , inSrcKind, inDstKind);
22+
1423 ur_device_usm_access_capability_flags_t device_usm = 0 ;
1524 ASSERT_SUCCESS (uur::GetDeviceUSMDeviceSupport (device, device_usm));
16- if (!device_usm) {
25+ if (!device_usm && (src_kind == uur::USMKind::Device ||
26+ dst_kind == uur::USMKind::Device)) {
1727 GTEST_SKIP () << " Device USM is not supported" ;
1828 }
1929
@@ -25,15 +35,13 @@ struct urEnqueueUSMMemcpy2DTestWithParam
2535 GTEST_SKIP () << " 2D USM memcpy is not supported" ;
2636 }
2737
28- const auto [inPitch, inWidth, inHeight] = getParam ();
29- std::tie (pitch, width, height) =
30- std::make_tuple (inPitch, inWidth, inHeight);
31-
3238 const size_t num_elements = pitch * height;
33- ASSERT_SUCCESS (urUSMDeviceAlloc (context, device, nullptr , nullptr ,
34- num_elements, &pSrc));
35- ASSERT_SUCCESS (urUSMDeviceAlloc (context, device, nullptr , nullptr ,
36- num_elements, &pDst));
39+ ASSERT_SUCCESS (uur::MakeUSMAllocationByType (
40+ src_kind, context, device, nullptr , nullptr , num_elements, &pSrc));
41+
42+ ASSERT_SUCCESS (uur::MakeUSMAllocationByType (
43+ dst_kind, context, device, nullptr , nullptr , num_elements, &pDst));
44+
3745 ur_event_handle_t memset_event = nullptr ;
3846
3947 ASSERT_SUCCESS (urEnqueueUSMFill (queue, pSrc, sizeof (memset_value),
@@ -52,17 +60,22 @@ struct urEnqueueUSMMemcpy2DTestWithParam
5260 if (pDst) {
5361 ASSERT_SUCCESS (urUSMFree (context, pDst));
5462 }
55- uur::urQueueTestWithParam<uur::TestParameters2D >::TearDown ();
63+ uur::urQueueTestWithParam<TestParametersMemcpy2D >::TearDown ();
5664 }
5765
5866 void verifyMemcpySucceeded () {
5967 std::vector<uint8_t > host_mem (pitch * height);
60- ASSERT_SUCCESS (urEnqueueUSMMemcpy2D (queue, true , host_mem.data (), pitch,
61- pDst, pitch, width, height, 0 ,
62- nullptr , nullptr ));
68+ const uint8_t *host_ptr = nullptr ;
69+ if (dst_kind == uur::USMKind::Device) {
70+ ASSERT_SUCCESS (urEnqueueUSMMemcpy2D (queue, true , host_mem.data (),
71+ pitch, pDst, pitch, width,
72+ height, 0 , nullptr , nullptr ));
73+ host_ptr = host_mem.data ();
74+ } else {
75+ host_ptr = static_cast <const uint8_t *>(pDst);
76+ }
6377 for (size_t w = 0 ; w < width; ++w) {
6478 for (size_t h = 0 ; h < height; ++h) {
65- const auto *host_ptr = host_mem.data ();
6679 const size_t index = (pitch * h) + w;
6780 ASSERT_TRUE (*(host_ptr + index) == memset_value);
6881 }
@@ -75,9 +88,11 @@ struct urEnqueueUSMMemcpy2DTestWithParam
7588 size_t pitch = 0 ;
7689 size_t width = 0 ;
7790 size_t height = 0 ;
91+ uur::USMKind src_kind;
92+ uur::USMKind dst_kind;
7893};
7994
80- static std::vector<uur::TestParameters2D> test_cases {
95+ static std::vector<uur::TestParameters2D> test_sizes {
8196 /* Everything set to 1 */
8297 {1 , 1 , 1 },
8398 /* Height == 1 && Pitch > width */
@@ -92,7 +107,13 @@ static std::vector<uur::TestParameters2D> test_cases{
92107 {234 , 233 , 1 }};
93108
94109UUR_TEST_SUITE_P (urEnqueueUSMMemcpy2DTestWithParam,
95- ::testing::ValuesIn (test_cases),
110+ ::testing::Combine (::testing::ValuesIn(test_sizes),
111+ ::testing::Values(uur::USMKind::Device,
112+ uur::USMKind::Host,
113+ uur::USMKind::Shared),
114+ ::testing::Values(uur::USMKind::Device,
115+ uur::USMKind::Host,
116+ uur::USMKind::Shared)),
96117 uur::print2DTestString<urEnqueueUSMMemcpy2DTestWithParam>);
97118
98119TEST_P (urEnqueueUSMMemcpy2DTestWithParam, SuccessBlocking) {
@@ -119,7 +140,8 @@ TEST_P(urEnqueueUSMMemcpy2DTestWithParam, SuccessNonBlocking) {
119140
120141using urEnqueueUSMMemcpy2DNegativeTest = urEnqueueUSMMemcpy2DTestWithParam;
121142UUR_TEST_SUITE_P (urEnqueueUSMMemcpy2DNegativeTest,
122- ::testing::Values (uur::TestParameters2D{1 , 1 , 1 }),
143+ ::testing::Values (TestParametersMemcpy2D{
144+ {1 , 1 , 1 }, uur::USMKind::Device, uur::USMKind::Device}),
123145 uur::print2DTestString<urEnqueueUSMMemcpy2DTestWithParam>);
124146
125147TEST_P (urEnqueueUSMMemcpy2DNegativeTest, InvalidNullHandleQueue) {
0 commit comments