22// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
33// See LICENSE.TXT
44// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5-
65#include < uur/fixtures.h>
76
8- struct urUSMDeviceAllocTest : uur::urQueueTestWithParam<uur::BoolTestParam> {
7+ using USMDeviceAllocParams = std::tuple<uur::BoolTestParam, uint32_t , size_t >;
8+
9+ template <typename T>
10+ inline std::string printUSMDeviceAllocTestString (
11+ const testing::TestParamInfo<typename T::ParamType> &info) {
12+ // ParamType will be std::tuple<ur_device_handle_t, USMDeviceAllocParams>
13+ const auto device_handle = std::get<0 >(info.param );
14+ const auto platform_device_name =
15+ uur::GetPlatformAndDeviceName (device_handle);
16+ const auto usmDeviceAllocParams = std::get<1 >(info.param );
17+ const auto BoolParam = std::get<0 >(usmDeviceAllocParams);
18+
19+ std::stringstream ss;
20+ ss << BoolParam.name << (BoolParam.value ? " Enabled" : " Disabled" );
21+
22+ const auto alignment = std::get<1 >(usmDeviceAllocParams);
23+ const auto size = std::get<2 >(usmDeviceAllocParams);
24+ if (alignment && size > 0 ) {
25+ ss << " _" ;
26+ ss << std::get<1 >(usmDeviceAllocParams);
27+ ss << " _" ;
28+ ss << std::get<2 >(usmDeviceAllocParams);
29+ }
30+
31+ return platform_device_name + " __" + ss.str ();
32+ }
33+
34+ struct urUSMDeviceAllocTest : uur::urQueueTestWithParam<USMDeviceAllocParams> {
935 void SetUp () override {
1036 UUR_RETURN_ON_FATAL_FAILURE (
11- uur::urQueueTestWithParam<uur::BoolTestParam >::SetUp ());
37+ uur::urQueueTestWithParam<USMDeviceAllocParams >::SetUp ());
1238 ur_device_usm_access_capability_flags_t deviceUSMSupport = 0 ;
1339 ASSERT_SUCCESS (
1440 uur::GetDeviceUSMDeviceSupport (device, deviceUSMSupport));
1541 if (!deviceUSMSupport) {
1642 GTEST_SKIP () << " Device USM is not supported." ;
1743 }
1844
19- if (getParam (). value ) {
45+ if (usePool ) {
2046 ur_usm_pool_desc_t pool_desc = {};
2147 ASSERT_SUCCESS (urUSMPoolCreate (context, &pool_desc, &pool));
2248 }
@@ -27,16 +53,20 @@ struct urUSMDeviceAllocTest : uur::urQueueTestWithParam<uur::BoolTestParam> {
2753 ASSERT_SUCCESS (urUSMPoolRelease (pool));
2854 }
2955 UUR_RETURN_ON_FATAL_FAILURE (
30- uur::urQueueTestWithParam<uur::BoolTestParam >::TearDown ());
56+ uur::urQueueTestWithParam<USMDeviceAllocParams >::TearDown ());
3157 }
3258
3359 ur_usm_pool_handle_t pool = nullptr ;
60+ bool usePool = std::get<0 >(getParam()).value;
3461};
3562
63+ // The 0 value parameters are not relevant for urUSMDeviceAllocTest tests, they are used below in urUSMDeviceAllocAlignmentTest
3664UUR_TEST_SUITE_P (
3765 urUSMDeviceAllocTest,
38- testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UsePool" )),
39- uur::deviceTestWithParamPrinter<uur::BoolTestParam>);
66+ testing::Combine (
67+ testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UsePool" )),
68+ testing::Values(0 ), testing::Values(0 )),
69+ printUSMDeviceAllocTestString<urUSMDeviceAllocTest>);
4070
4171TEST_P (urUSMDeviceAllocTest, Success) {
4272 void *ptr = nullptr ;
@@ -69,6 +99,7 @@ TEST_P(urUSMDeviceAllocTest, SuccessWithDescriptors) {
6999 size_t allocation_size = sizeof (int );
70100 ASSERT_SUCCESS (urUSMDeviceAlloc (context, device, &usm_desc, pool,
71101 allocation_size, &ptr));
102+ ASSERT_NE (ptr, nullptr );
72103
73104 ur_event_handle_t event = nullptr ;
74105 uint8_t pattern = 0 ;
@@ -116,3 +147,38 @@ TEST_P(urUSMDeviceAllocTest, InvalidValueAlignPowerOfTwo) {
116147 UR_RESULT_ERROR_INVALID_VALUE,
117148 urUSMDeviceAlloc (context, device, &desc, pool, sizeof (int ), &ptr));
118149}
150+
151+ using urUSMDeviceAllocAlignmentTest = urUSMDeviceAllocTest;
152+
153+ UUR_TEST_SUITE_P (
154+ urUSMDeviceAllocAlignmentTest,
155+ testing::Combine (
156+ testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UsePool" )),
157+ testing::Values(4 , 8 , 16 , 32 , 64 ), testing::Values(8 , 512 , 2048 )),
158+ printUSMDeviceAllocTestString<urUSMDeviceAllocAlignmentTest>);
159+
160+ TEST_P (urUSMDeviceAllocAlignmentTest, SuccessAlignedAllocations) {
161+ uint32_t alignment = std::get<1 >(getParam ());
162+ size_t allocation_size = std::get<2 >(getParam ());
163+
164+ ur_usm_device_desc_t usm_device_desc{UR_STRUCTURE_TYPE_USM_DEVICE_DESC,
165+ nullptr ,
166+ /* device flags */ 0 };
167+
168+ ur_usm_desc_t usm_desc{UR_STRUCTURE_TYPE_USM_DESC, &usm_device_desc,
169+ /* mem advice flags */ UR_USM_ADVICE_FLAG_DEFAULT,
170+ alignment};
171+ void *ptr = nullptr ;
172+ ASSERT_SUCCESS (urUSMDeviceAlloc (context, device, &usm_desc, pool,
173+ allocation_size, &ptr));
174+ ASSERT_NE (ptr, nullptr );
175+
176+ ur_event_handle_t event = nullptr ;
177+ uint8_t pattern = 0 ;
178+ ASSERT_SUCCESS (urEnqueueUSMFill (queue, ptr, sizeof (pattern), &pattern,
179+ allocation_size, 0 , nullptr , &event));
180+ ASSERT_SUCCESS (urEventWait (1 , &event));
181+
182+ ASSERT_SUCCESS (urUSMFree (context, ptr));
183+ EXPECT_SUCCESS (urEventRelease (event));
184+ }
0 commit comments