Skip to content

Commit b05f256

Browse files
authored
[SYCL][ESIMD] Limit allowed argument types for rol/ror functions (#6569)
* Limit allowed argument types for rol/ror functions
1 parent 7805aa3 commit b05f256

File tree

1 file changed

+52
-30
lines changed
  • sycl/include/sycl/ext/intel/experimental/esimd

1 file changed

+52
-30
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/math.hpp

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,13 @@ shr(T1 src0, T2 src1, Sat sat = {}) {
276276
/// the input vector \p src0 shall be rotated.
277277
/// @return vector of rotated elements.
278278
template <typename T0, typename T1, int SZ>
279-
__ESIMD_API
280-
std::enable_if_t<std::is_integral<T0>::value && std::is_integral<T1>::value,
281-
__ESIMD_NS::simd<T0, SZ>>
282-
rol(__ESIMD_NS::simd<T1, SZ> src0, __ESIMD_NS::simd<T1, SZ> src1) {
279+
__ESIMD_API std::enable_if_t<
280+
__ESIMD_NS::detail::is_type<T0, int16_t, uint16_t, int32_t, uint32_t,
281+
int64_t, uint64_t>() &&
282+
__ESIMD_NS::detail::is_type<T1, int16_t, uint16_t, int32_t, uint32_t,
283+
int64_t, uint64_t>(),
284+
__ESIMD_NS::simd<T0, SZ>>
285+
rol(__ESIMD_NS::simd<T1, SZ> src0, __ESIMD_NS::simd<T1, SZ> src1) {
283286
return __esimd_rol<T0, T1, SZ>(src0.data(), src1.data());
284287
}
285288

@@ -292,10 +295,14 @@ __ESIMD_API
292295
/// @param src1 the number of bit positions the input vector shall be rotated.
293296
/// @return vector of rotated elements.
294297
template <typename T0, typename T1, int SZ, typename U>
295-
__ESIMD_API std::enable_if_t<std::is_integral<T0>::value &&
296-
std::is_integral<T1>::value &&
297-
std::is_integral<U>::value,
298-
__ESIMD_NS::simd<T0, SZ>>
298+
__ESIMD_API std::enable_if_t<
299+
__ESIMD_NS::detail::is_type<T0, int16_t, uint16_t, int32_t, uint32_t,
300+
int64_t, uint64_t>() &&
301+
__ESIMD_NS::detail::is_type<T1, int16_t, uint16_t, int32_t, uint32_t,
302+
int64_t, uint64_t>() &&
303+
__ESIMD_NS::detail::is_type<U, int16_t, uint16_t, int32_t, uint32_t,
304+
int64_t, uint64_t>(),
305+
__ESIMD_NS::simd<T0, SZ>>
299306
rol(__ESIMD_NS::simd<T1, SZ> src0, U src1) {
300307
__ESIMD_NS::simd<T1, SZ> Src1 = src1;
301308
return esimd::rol<T0>(src0, Src1);
@@ -309,13 +316,17 @@ rol(__ESIMD_NS::simd<T1, SZ> src0, U src1) {
309316
/// @param src1 the number of bit positions the input vector shall be rotated.
310317
/// @return rotated left value.
311318
template <typename T0, typename T1, typename T2>
312-
__ESIMD_API std::enable_if_t<__ESIMD_DNS::is_esimd_scalar<T0>::value &&
313-
__ESIMD_DNS::is_esimd_scalar<T1>::value &&
314-
__ESIMD_DNS::is_esimd_scalar<T2>::value &&
315-
std::is_integral<T0>::value &&
316-
std::is_integral<T1>::value &&
317-
std::is_integral<T2>::value,
318-
std::remove_const_t<T0>>
319+
__ESIMD_API std::enable_if_t<
320+
__ESIMD_DNS::is_esimd_scalar<T0>::value &&
321+
__ESIMD_DNS::is_esimd_scalar<T1>::value &&
322+
__ESIMD_DNS::is_esimd_scalar<T2>::value &&
323+
__ESIMD_NS::detail::is_type<T0, int16_t, uint16_t, int32_t, uint32_t,
324+
int64_t, uint64_t>() &&
325+
__ESIMD_NS::detail::is_type<T1, int16_t, uint16_t, int32_t, uint32_t,
326+
int64_t, uint64_t>() &&
327+
__ESIMD_NS::detail::is_type<T2, int16_t, uint16_t, int32_t, uint32_t,
328+
int64_t, uint64_t>(),
329+
std::remove_const_t<T0>>
319330
rol(T1 src0, T2 src1) {
320331
__ESIMD_NS::simd<T1, 1> Src0 = src0;
321332
__ESIMD_NS::simd<T0, 1> Result = esimd::rol<T0, T1, 1, T2>(Src0, src1);
@@ -331,10 +342,13 @@ rol(T1 src0, T2 src1) {
331342
/// the input vector \p src0 shall be rotated.
332343
/// @return vector of rotated elements.
333344
template <typename T0, typename T1, int SZ>
334-
__ESIMD_API
335-
std::enable_if_t<std::is_integral<T0>::value && std::is_integral<T1>::value,
336-
__ESIMD_NS::simd<T0, SZ>>
337-
ror(__ESIMD_NS::simd<T1, SZ> src0, __ESIMD_NS::simd<T1, SZ> src1) {
345+
__ESIMD_API std::enable_if_t<
346+
__ESIMD_NS::detail::is_type<T0, int16_t, uint16_t, int32_t, uint32_t,
347+
int64_t, uint64_t>() &&
348+
__ESIMD_NS::detail::is_type<T1, int16_t, uint16_t, int32_t, uint32_t,
349+
int64_t, uint64_t>(),
350+
__ESIMD_NS::simd<T0, SZ>>
351+
ror(__ESIMD_NS::simd<T1, SZ> src0, __ESIMD_NS::simd<T1, SZ> src1) {
338352
return __esimd_ror<T0, T1, SZ>(src0.data(), src1.data());
339353
}
340354

@@ -347,10 +361,14 @@ __ESIMD_API
347361
/// @param src1 the number of bit positions the input vector shall be rotated.
348362
/// @return vector of rotated elements.
349363
template <typename T0, typename T1, int SZ, typename U>
350-
__ESIMD_API std::enable_if_t<std::is_integral<T0>::value &&
351-
std::is_integral<T1>::value &&
352-
std::is_integral<U>::value,
353-
__ESIMD_NS::simd<T0, SZ>>
364+
__ESIMD_API std::enable_if_t<
365+
__ESIMD_NS::detail::is_type<T0, int16_t, uint16_t, int32_t, uint32_t,
366+
int64_t, uint64_t>() &&
367+
__ESIMD_NS::detail::is_type<T1, int16_t, uint16_t, int32_t, uint32_t,
368+
int64_t, uint64_t>() &&
369+
__ESIMD_NS::detail::is_type<U, int16_t, uint16_t, int32_t, uint32_t,
370+
int64_t, uint64_t>(),
371+
__ESIMD_NS::simd<T0, SZ>>
354372
ror(__ESIMD_NS::simd<T1, SZ> src0, U src1) {
355373
__ESIMD_NS::simd<T1, SZ> Src1 = src1;
356374
return esimd::ror<T0>(src0, Src1);
@@ -364,13 +382,17 @@ ror(__ESIMD_NS::simd<T1, SZ> src0, U src1) {
364382
/// @param src1 the number of bit positions the input vector shall be rotated.
365383
/// @return rotated right value.
366384
template <typename T0, typename T1, typename T2>
367-
__ESIMD_API std::enable_if_t<__ESIMD_DNS::is_esimd_scalar<T0>::value &&
368-
__ESIMD_DNS::is_esimd_scalar<T1>::value &&
369-
__ESIMD_DNS::is_esimd_scalar<T2>::value &&
370-
std::is_integral<T0>::value &&
371-
std::is_integral<T1>::value &&
372-
std::is_integral<T2>::value,
373-
std::remove_const_t<T0>>
385+
__ESIMD_API std::enable_if_t<
386+
__ESIMD_DNS::is_esimd_scalar<T0>::value &&
387+
__ESIMD_DNS::is_esimd_scalar<T1>::value &&
388+
__ESIMD_DNS::is_esimd_scalar<T2>::value &&
389+
__ESIMD_NS::detail::is_type<T0, int16_t, uint16_t, int32_t, uint32_t,
390+
int64_t, uint64_t>() &&
391+
__ESIMD_NS::detail::is_type<T1, int16_t, uint16_t, int32_t, uint32_t,
392+
int64_t, uint64_t>() &&
393+
__ESIMD_NS::detail::is_type<T2, int16_t, uint16_t, int32_t, uint32_t,
394+
int64_t, uint64_t>(),
395+
std::remove_const_t<T0>>
374396
ror(T1 src0, T2 src1) {
375397
__ESIMD_NS::simd<T1, 1> Src0 = src0;
376398
__ESIMD_NS::simd<T0, 1> Result = esimd::ror<T0, T1, 1, T2>(Src0, src1);

0 commit comments

Comments
 (0)