Skip to content

Commit ab2da2d

Browse files
authored
[ESIMD]Add static_asserts to existing autodeduction API (#13977)
1 parent 43f6332 commit ab2da2d

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,10 @@ __ESIMD_API std::enable_if_t<
623623
simd<T, N>>
624624
gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask<N / VS> mask,
625625
PassThruSimdViewT pass_thru, PropertyListT props = {}) {
626+
static_assert(N / VS ==
627+
OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(),
628+
"Size of pass_thru parameter must correspond to the size of "
629+
"byte_offsets parameter.");
626630
return gather<T, N, VS>(p, byte_offsets.read(), mask, pass_thru.read(),
627631
props);
628632
}
@@ -662,6 +666,10 @@ __ESIMD_API std::enable_if_t<
662666
simd<T, N>>
663667
gather(const T *p, OffsetSimdViewT byte_offsets, simd_mask<N / VS> mask,
664668
simd<T, N> pass_thru, PropertyListT props = {}) {
669+
static_assert(N / VS ==
670+
OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(),
671+
"Size of pass_thru parameter must correspond to the size of "
672+
"byte_offsets parameter.");
665673
return gather<T, N, VS>(p, byte_offsets.read(), mask, pass_thru, props);
666674
}
667675

@@ -1012,6 +1020,10 @@ __ESIMD_API std::enable_if_t<
10121020
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
10131021
scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
10141022
simd_mask<N / VS> mask, PropertyListT props = {}) {
1023+
static_assert(N / VS ==
1024+
OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(),
1025+
"Size of vals parameter must correspond to the size of "
1026+
"byte_offsets parameter.");
10151027
scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), mask, props);
10161028
}
10171029

@@ -1116,6 +1128,10 @@ __ESIMD_API std::enable_if_t<
11161128
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
11171129
scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
11181130
simd_mask<N / VS> mask, PropertyListT props = {}) {
1131+
static_assert(N / VS ==
1132+
OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(),
1133+
"Size of vals parameter must correspond to the size of "
1134+
"byte_offsets parameter.");
11191135
scatter<T, N, VS>(p, byte_offsets.read(), vals, mask, props);
11201136
}
11211137

@@ -1150,6 +1166,10 @@ __ESIMD_API std::enable_if_t<
11501166
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
11511167
scatter(T *p, OffsetSimdViewT byte_offsets, simd<T, N> vals,
11521168
PropertyListT props = {}) {
1169+
static_assert(N / VS ==
1170+
OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(),
1171+
"Size of vals parameter must correspond to the size of "
1172+
"byte_offsets parameter.");
11531173
scatter<T, N, VS>(p, byte_offsets.read(), vals, props);
11541174
}
11551175

@@ -1221,8 +1241,11 @@ __ESIMD_API std::enable_if_t<
12211241
ext::oneapi::experimental::is_property_list_v<PropertyListT>>
12221242
scatter(T *p, OffsetSimdViewT byte_offsets, ValuesSimdViewT vals,
12231243
PropertyListT props = {}) {
1224-
simd_mask<N / VS> Mask = 1;
1225-
scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), Mask, props);
1244+
static_assert(N / VS ==
1245+
OffsetSimdViewT::getSizeX() * OffsetSimdViewT::getSizeY(),
1246+
"Size of vals parameter must correspond to the size of "
1247+
"byte_offsets parameter.");
1248+
scatter<T, N, VS>(p, byte_offsets.read(), vals.read(), props);
12261249
}
12271250

12281251
/// A variation of \c scatter API with \c offsets represented as scalar.

0 commit comments

Comments
 (0)