Skip to content

Commit 27dab6c

Browse files
AndreiZibrovAndrei Zibrovsteffenlarsen
authored
[SYCL] Fix has_extension string-based implementation (#19264)
This PR is requested to fix has_extension string comparison found during #19238 + testing to show all cases --------- Signed-off-by: Larsen, Steffen <[email protected]> Co-authored-by: Andrei Zibrov <[email protected]> Co-authored-by: Larsen, Steffen <[email protected]>
1 parent dde5462 commit 27dab6c

File tree

4 files changed

+154
-7
lines changed

4 files changed

+154
-7
lines changed

sycl/source/detail/device_impl.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,13 @@ device_impl::get_backend_info<info::device::backend_version>() const {
106106
#endif
107107

108108
bool device_impl::has_extension(const std::string &ExtensionName) const {
109-
std::string AllExtensionNames = get_info_impl<UR_DEVICE_INFO_EXTENSIONS>();
109+
const std::string AllExtensionNames{
110+
get_info_impl<UR_DEVICE_INFO_EXTENSIONS>()};
110111

111-
return (AllExtensionNames.find(ExtensionName) != std::string::npos);
112+
// We add a space to both sides of both the extension string and the query
113+
// string. This prevents to lookup from finding partial extension matches.
114+
return ((" " + AllExtensionNames + " ").find(" " + ExtensionName + " ") !=
115+
std::string::npos);
112116
}
113117

114118
bool device_impl::is_partition_supported(info::partition_property Prop) const {

sycl/unittests/buffer/BufferLocation.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,22 @@ static ur_result_t redefinedDeviceGetInfoAfter(void *pParams) {
6363
const size_t nameSize = name.size() + 1;
6464

6565
if (!*params->ppPropValue) {
66+
size_t beforeSize = **params->ppPropSizeRet;
6667
// Choose bigger size so that both original and redefined function
67-
// has enough memory for storing the extension string
68-
**params->ppPropSizeRet = nameSize > **params->ppPropSizeRet
69-
? nameSize
70-
: **params->ppPropSizeRet;
68+
// has enough memory for storing the extension string. If the original has
69+
// reported it has a non-empty string to report, we additionally need room
70+
// for a space.
71+
**params->ppPropSizeRet = beforeSize + (beforeSize > 0) + nameSize;
7172
} else {
73+
assert(*params->ppropSize >= nameSize);
74+
// Insert at the end of the extension string.
75+
size_t nameOffset = *params->ppropSize - nameSize;
7276
char *dst = static_cast<char *>(*params->ppPropValue);
73-
strcpy(dst, name.data());
77+
// If the offset isn't at the start of the string, we need to insert a
78+
// space before it.
79+
if (nameOffset > 0)
80+
dst[nameOffset - 1] = ' ';
81+
strcpy(dst + *params->ppropSize - nameSize, name.data());
7482
}
7583
break;
7684
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_sycl_unittest(ContextDeviceTests OBJECT
22
Context.cpp
33
DeviceRefCounter.cpp
4+
HasExtensionWordBoundary.cpp
45
)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
//==---- HasExtensionWordBoundary.cpp --- Test word boundary fix ----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This test verifies that has_extension correctly matches full extension names
10+
// and doesn't match partial substrings.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include <detail/device_impl.hpp>
15+
#include <gtest/gtest.h>
16+
#include <helpers/UrMock.hpp>
17+
#include <sycl/sycl.hpp>
18+
#include <ur_mock_helpers.hpp>
19+
20+
using namespace sycl;
21+
22+
thread_local std::string MockExtensions = "";
23+
24+
static ur_result_t redefinedDeviceGetInfo(void *pParams) {
25+
auto Params = *static_cast<ur_device_get_info_params_t *>(pParams);
26+
27+
if (*Params.ppropName == UR_DEVICE_INFO_EXTENSIONS) {
28+
// Override extensions query with mock data.
29+
if (*Params.ppPropValue) {
30+
size_t Len = MockExtensions.length() + 1;
31+
if (*Params.ppropSize >= Len)
32+
std::memcpy(*Params.ppPropValue, MockExtensions.c_str(), Len);
33+
}
34+
if (*Params.ppPropSizeRet)
35+
**Params.ppPropSizeRet = MockExtensions.length() + 1;
36+
37+
return UR_RESULT_SUCCESS;
38+
}
39+
40+
// Delegate to the default mock.
41+
return sycl::unittest::MockAdapter::mock_urDeviceGetInfo(pParams);
42+
}
43+
44+
class HasExtensionWordBoundaryTest : public ::testing::Test {
45+
public:
46+
HasExtensionWordBoundaryTest() : Mock{} {}
47+
48+
protected:
49+
void SetUp() override {
50+
mock::getCallbacks().set_replace_callback("urDeviceGetInfo",
51+
&redefinedDeviceGetInfo);
52+
}
53+
54+
sycl::unittest::UrMock<> Mock;
55+
};
56+
57+
TEST_F(HasExtensionWordBoundaryTest, ExactMatchWorks) {
58+
MockExtensions = "cl_khr_fp64 cl_intel_subgroups cl_khr_subgroups";
59+
60+
sycl::platform Plt{sycl::platform()};
61+
sycl::device Dev = Plt.get_devices()[0];
62+
63+
EXPECT_TRUE(Dev.has_extension("cl_khr_fp64"));
64+
EXPECT_TRUE(Dev.has_extension("cl_intel_subgroups"));
65+
EXPECT_TRUE(Dev.has_extension("cl_khr_subgroups"));
66+
}
67+
68+
TEST_F(HasExtensionWordBoundaryTest, SubstringDoesNotMatch) {
69+
MockExtensions = "cl_intel_subgroups cl_khr_fp64_extended";
70+
71+
sycl::platform Plt{sycl::platform()};
72+
sycl::device Dev = Plt.get_devices()[0];
73+
74+
EXPECT_FALSE(Dev.has_extension("cl_intel_subgroup"));
75+
EXPECT_FALSE(Dev.has_extension("cl_khr_fp64"));
76+
EXPECT_FALSE(Dev.has_extension("subgroups"));
77+
EXPECT_FALSE(Dev.has_extension("intel_subgroups"));
78+
}
79+
80+
TEST_F(HasExtensionWordBoundaryTest, EmptyExtensions) {
81+
MockExtensions = "";
82+
83+
sycl::platform Plt{sycl::platform()};
84+
sycl::device Dev = Plt.get_devices()[0];
85+
86+
EXPECT_FALSE(Dev.has_extension("cl_khr_fp64"));
87+
}
88+
89+
TEST_F(HasExtensionWordBoundaryTest, SingleExtension) {
90+
MockExtensions = "cl_khr_fp64";
91+
92+
sycl::platform Plt{sycl::platform()};
93+
sycl::device Dev = Plt.get_devices()[0];
94+
auto DevImpl = detail::getSyclObjImpl(Dev);
95+
96+
EXPECT_TRUE(Dev.has_extension("cl_khr_fp64"));
97+
EXPECT_FALSE(Dev.has_extension("cl_khr_fp6"));
98+
}
99+
100+
TEST_F(HasExtensionWordBoundaryTest, FirstMiddleLastExtensions) {
101+
MockExtensions = "cl_first_ext cl_middle_ext cl_last_ext";
102+
103+
sycl::platform Plt{sycl::platform()};
104+
sycl::device Dev = Plt.get_devices()[0];
105+
auto DevImpl = detail::getSyclObjImpl(Dev);
106+
107+
EXPECT_TRUE(Dev.has_extension("cl_first_ext"));
108+
EXPECT_TRUE(Dev.has_extension("cl_middle_ext"));
109+
EXPECT_TRUE(Dev.has_extension("cl_last_ext"));
110+
}
111+
112+
TEST_F(HasExtensionWordBoundaryTest, NonUniformGroupExtensions) {
113+
MockExtensions = "cl_khr_subgroup_non_uniform_vote "
114+
"cl_khr_subgroup_ballot "
115+
"cl_intel_subgroups "
116+
"cl_intel_spirv_subgroups "
117+
"cl_intel_subgroup_matrix_multiply_accumulate";
118+
119+
sycl::platform Plt{sycl::platform()};
120+
sycl::device Dev = Plt.get_devices()[0];
121+
auto DevImpl = detail::getSyclObjImpl(Dev);
122+
123+
EXPECT_TRUE(Dev.has_extension("cl_khr_subgroup_non_uniform_vote"));
124+
EXPECT_TRUE(Dev.has_extension("cl_khr_subgroup_ballot"));
125+
EXPECT_TRUE(Dev.has_extension("cl_intel_subgroups"));
126+
EXPECT_TRUE(Dev.has_extension("cl_intel_spirv_subgroups"));
127+
EXPECT_TRUE(
128+
Dev.has_extension("cl_intel_subgroup_matrix_multiply_accumulate"));
129+
130+
EXPECT_FALSE(Dev.has_extension("cl_khr_subgroup"));
131+
EXPECT_FALSE(Dev.has_extension("cl_intel_subgroup"));
132+
EXPECT_FALSE(Dev.has_extension("non_uniform_vote"));
133+
EXPECT_FALSE(Dev.has_extension("subgroup_matrix_multiply"));
134+
}

0 commit comments

Comments
 (0)