Skip to content

Commit ec607fb

Browse files
authored
[SYCL] Fix sampled bindless fetch_image for float type (#20107)
The implementation previously used `__spirv_ImageSampleExplicitLod` which returned incorrect results when the `DataT` return type was a vector of floats. This change makes it use `__spirv_ImageRead` instead and adds a test case for the `float` image type. Signed-off-by: Michael Aziz <[email protected]>
1 parent bb6e44c commit ec607fb

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

sycl/include/sycl/ext/oneapi/bindless_images.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,15 +1036,19 @@ DataT fetch_image(const sampled_image_handle &imageHandle [[maybe_unused]],
10361036
"HintT must always be a recognized standard type");
10371037

10381038
#ifdef __SYCL_DEVICE_ONLY__
1039+
// Convert the raw handle to an image and use FETCH_UNSAMPLED_IMAGE since
1040+
// fetch_image should not use the sampler
10391041
if constexpr (detail::is_recognized_standard_type<DataT>()) {
1040-
return FETCH_SAMPLED_IMAGE(
1042+
return FETCH_UNSAMPLED_IMAGE(
10411043
DataT,
1042-
CONVERT_HANDLE_TO_SAMPLED_IMAGE(imageHandle.raw_handle, coordSize),
1044+
CONVERT_HANDLE_TO_IMAGE(imageHandle.raw_handle,
1045+
detail::OCLImageTyRead<coordSize>),
10431046
coords);
10441047
} else {
1045-
return sycl::bit_cast<DataT>(FETCH_SAMPLED_IMAGE(
1048+
return sycl::bit_cast<DataT>(FETCH_UNSAMPLED_IMAGE(
10461049
HintT,
1047-
CONVERT_HANDLE_TO_SAMPLED_IMAGE(imageHandle.raw_handle, coordSize),
1050+
CONVERT_HANDLE_TO_IMAGE(imageHandle.raw_handle,
1051+
detail::OCLImageTyRead<coordSize>),
10481052
coords));
10491053
}
10501054
#else

sycl/test-e2e/bindless_images/sampled_fetch/fetch_2D_USM_device.cpp

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
#include <sycl/ext/oneapi/bindless_images.hpp>
1212
#include <sycl/usm.hpp>
1313

14-
class kernel_sampled_fetch;
14+
namespace {
1515

16-
int main() {
16+
template <typename T, sycl::image_channel_type ChanType>
17+
static int testSampledImageFetch() {
1718

1819
sycl::device dev;
1920
sycl::queue q(dev);
@@ -23,9 +24,9 @@ int main() {
2324
constexpr size_t width = 5;
2425
constexpr size_t height = 6;
2526
constexpr size_t N = width * height;
26-
std::vector<sycl::vec<uint16_t, 4>> out(N);
27-
std::vector<sycl::vec<uint16_t, 4>> expected(N);
28-
std::vector<sycl::vec<uint16_t, 4>> dataIn(N);
27+
std::vector<sycl::vec<T, 4>> out(N);
28+
std::vector<sycl::vec<T, 4>> expected(N);
29+
std::vector<sycl::vec<T, 4>> dataIn(N);
2930
for (int i = 0; i < width; i++) {
3031
for (int j = 0; j < height; j++) {
3132
auto index = i + (width * j);
@@ -43,8 +44,7 @@ int main() {
4344
sycl::filtering_mode::linear);
4445

4546
// Extension: image descriptor
46-
syclexp::image_descriptor desc({width, height}, 4,
47-
sycl::image_channel_type::unsigned_int16);
47+
syclexp::image_descriptor desc({width, height}, 4, ChanType);
4848
size_t pitch = 0;
4949

5050
// Extension: returns the device pointer to USM allocated pitched memory
@@ -65,21 +65,20 @@ int main() {
6565

6666
sycl::buffer buf(out.data(), sycl::range{height, width});
6767
q.submit([&](sycl::handler &cgh) {
68-
auto outAcc = buf.get_access<sycl::access_mode::write>(
68+
auto outAcc = buf.template get_access<sycl::access_mode::write>(
6969
cgh, sycl::range<2>{height, width});
7070

71-
cgh.parallel_for<kernel_sampled_fetch>(
72-
sycl::nd_range<2>{{width, height}, {width, height}},
73-
[=](sycl::nd_item<2> it) {
74-
size_t dim0 = it.get_local_id(0);
75-
size_t dim1 = it.get_local_id(1);
71+
cgh.parallel_for(sycl::nd_range<2>{{width, height}, {width, height}},
72+
[=](sycl::nd_item<2> it) {
73+
size_t dim0 = it.get_local_id(0);
74+
size_t dim1 = it.get_local_id(1);
7675

77-
// Extension: fetch data from sampled image handle
78-
auto px1 = syclexp::fetch_image<sycl::vec<uint16_t, 4>>(
79-
imgHandle, sycl::int2(dim0, dim1));
76+
// Extension: fetch data from sampled image handle
77+
auto px1 = syclexp::fetch_image<sycl::vec<T, 4>>(
78+
imgHandle, sycl::int2(dim0, dim1));
8079

81-
outAcc[sycl::id<2>{dim1, dim0}] = px1;
82-
});
80+
outAcc[sycl::id<2>{dim1, dim0}] = px1;
81+
});
8382
});
8483

8584
q.wait_and_throw();
@@ -121,3 +120,23 @@ int main() {
121120
std::cout << "Test failed!" << std::endl;
122121
return 3;
123122
}
123+
124+
} // namespace
125+
126+
int main() {
127+
if (int err =
128+
testSampledImageFetch<uint16_t,
129+
sycl::image_channel_type::unsigned_int16>()) {
130+
return err;
131+
}
132+
if (int err =
133+
testSampledImageFetch<uint32_t,
134+
sycl::image_channel_type::unsigned_int32>()) {
135+
return err;
136+
}
137+
if (int err =
138+
testSampledImageFetch<float, sycl::image_channel_type::fp32>()) {
139+
return err;
140+
}
141+
return 0;
142+
}

0 commit comments

Comments
 (0)