Skip to content

Commit 51860b3

Browse files
committed
Add bfloat16-container support to clamp
i.e. sycl::marray<bfloat16, N>, sycl::vec<bfloat16, N>
1 parent 5b0e0b5 commit 51860b3

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@ inline sycl::ext::oneapi::bfloat16 clamp(sycl::ext::oneapi::bfloat16 val,
8181
return max_val;
8282
return val;
8383
}
84+
85+
template <typename T, int Size>
86+
inline std::enable_if_t<std::is_same_v<T, sycl::ext::oneapi::bfloat16>,
87+
sycl::vec<T, Size>>
88+
clamp(sycl::vec<T, Size> val, sycl::vec<T, Size> min_val,
89+
sycl::vec<T, Size> max_val) {
90+
return [&val, &min_val, &max_val]<int... I>(std::integer_sequence<int, I...>) {
91+
return sycl::vec<T, Size>{
92+
clamp<sycl::ext::oneapi::bfloat16>(val[I], min_val[I], max_val[I])...};
93+
}(std::make_integer_sequence<int, Size>{});
94+
}
95+
96+
template <typename T, std::size_t Size>
97+
inline std::enable_if_t<std::is_same_v<T, sycl::ext::oneapi::bfloat16>,
98+
sycl::marray<T, Size>>
99+
clamp(sycl::marray<T, Size> val, sycl::marray<T, Size> min_val,
100+
sycl::marray<T, Size> max_val) {
101+
return [&val, &min_val, &max_val]<std::size_t... I>(std::index_sequence<I...>) {
102+
return sycl::marray<T, Size>{
103+
clamp<sycl::ext::oneapi::bfloat16>(val[I], min_val[I], max_val[I])...};
104+
}(std::make_index_sequence<Size>{});
105+
}
84106
#endif
85107

86108
template <typename VecT, class BinaryOperation, class = void>

sycl/test-e2e/syclcompat/math/math_ops.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,24 @@ template <typename ValueT> void test_clamp() {
317317
.template launch_test<clamp_kernel<ValueT>>(op3, expect3);
318318
}
319319

320+
template <template <typename T, int Dim> typename ContainerT, typename ValueT> void test_container_clamp() {
321+
std::cout << __PRETTY_FUNCTION__ << std::endl;
322+
323+
constexpr syclcompat::dim3 grid{1};
324+
constexpr syclcompat::dim3 threads{1};
325+
ValueT op1 = static_cast<ValueT>(7);
326+
ValueT expect1 = static_cast<ValueT>(7);
327+
328+
ValueT op2 = static_cast<ValueT>(MAX_CLAMP + 1);
329+
ValueT expect2 = static_cast<ValueT>(MAX_CLAMP);
330+
331+
using ContT = ContainerT<ValueT, 2>;
332+
const ContT op4{op1, op2};
333+
const ContT expect4{expect1, expect2};
334+
UnaryOpTestLauncher<ContT>(grid, threads)
335+
.template launch_test<clamp_kernel<ContT>>(op4, expect4);
336+
}
337+
320338
int main() {
321339
INSTANTIATE_ALL_TYPES(value_type_list, test_syclcompat_max);
322340
INSTANTIATE_ALL_TYPES(value_type_list, test_syclcompat_min);
@@ -356,6 +374,8 @@ int main() {
356374
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_isnan);
357375

358376
INSTANTIATE_ALL_TYPES(value_type_list, test_clamp);
377+
INSTANTIATE_ALL_CONTAINER_TYPES(vec_type_list, sycl::vec, test_container_clamp);
378+
INSTANTIATE_ALL_CONTAINER_TYPES(marray_type_list, sycl::marray, test_container_clamp);
359379

360380
return 0;
361381
}

0 commit comments

Comments
 (0)