Skip to content

Commit 5a6085f

Browse files
authored
Merge branch 'intel:sycl' into work_group_memoy_new
2 parents 1783f75 + b2f6326 commit 5a6085f

File tree

10 files changed

+172
-87
lines changed

10 files changed

+172
-87
lines changed

sycl/include/sycl/accessor_image.hpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,21 @@
1212
namespace sycl {
1313
inline namespace _V1 {
1414
namespace detail {
15-
template <int Dim, typename T> struct IsValidCoordDataT;
16-
template <typename T> struct IsValidCoordDataT<1, T> {
17-
constexpr static bool value = detail::is_contained<
18-
T, detail::type_list<opencl::cl_int, opencl::cl_float>>::type::value;
15+
template <int Dim, typename T, bool AllowFP = true> struct IsValidCoordDataT;
16+
template <typename T, bool AllowFP> struct IsValidCoordDataT<1, T, AllowFP> {
17+
constexpr static bool value =
18+
std::is_same_v<T, opencl::cl_int> ||
19+
(AllowFP && std::is_same_v<T, opencl::cl_float>);
1920
};
20-
template <typename T> struct IsValidCoordDataT<2, T> {
21-
constexpr static bool value = detail::is_contained<
22-
T, detail::type_list<vec<opencl::cl_int, 2>,
23-
vec<opencl::cl_float, 2>>>::type::value;
21+
template <typename T, bool AllowFP> struct IsValidCoordDataT<2, T, AllowFP> {
22+
constexpr static bool value =
23+
std::is_same_v<T, vec<opencl::cl_int, 2>> ||
24+
(AllowFP && std::is_same_v<T, vec<opencl::cl_float, 2>>);
2425
};
25-
template <typename T> struct IsValidCoordDataT<3, T> {
26-
constexpr static bool value = detail::is_contained<
27-
T, detail::type_list<vec<opencl::cl_int, 4>,
28-
vec<opencl::cl_float, 4>>>::type::value;
26+
template <typename T, bool AllowFP> struct IsValidCoordDataT<3, T, AllowFP> {
27+
constexpr static bool value =
28+
std::is_same_v<T, vec<opencl::cl_int, 4>> ||
29+
(AllowFP && std::is_same_v<T, vec<opencl::cl_float, 4>>);
2930
};
3031

3132
template <int Dim, typename T> struct IsValidUnsampledCoord2020DataT;
@@ -448,12 +449,12 @@ class image_accessor
448449
// (accessTarget == access::target::image && accessMode == access::mode::read)
449450
// || (accessTarget == access::target::host_image && ( accessMode ==
450451
// access::mode::read || accessMode == access::mode::read_write))
451-
template <typename CoordT, int Dims = Dimensions,
452-
typename = std::enable_if_t<
453-
(Dims > 0) && (IsValidCoordDataT<Dims, CoordT>::value) &&
454-
(detail::is_genint_v<CoordT>) &&
455-
((IsImageAcc && IsImageAccessReadOnly) ||
456-
(IsHostImageAcc && IsImageAccessAnyRead))>>
452+
template <
453+
typename CoordT, int Dims = Dimensions,
454+
typename = std::enable_if_t<
455+
(IsValidCoordDataT<Dims, CoordT, /* AllowFP = */ false>::value) &&
456+
((IsImageAcc && IsImageAccessReadOnly) ||
457+
(IsHostImageAcc && IsImageAccessAnyRead))>>
457458
DataT read(const CoordT &Coords) const {
458459
#ifdef __SYCL_DEVICE_ONLY__
459460
return __invoke__ImageRead<DataT, OCLImageTy, CoordT>(MImageObj, Coords);
@@ -470,7 +471,7 @@ class image_accessor
470471
// access::mode::read || accessMode == access::mode::read_write))
471472
template <typename CoordT, int Dims = Dimensions,
472473
typename = std::enable_if_t<
473-
(Dims > 0) && (IsValidCoordDataT<Dims, CoordT>::value) &&
474+
(IsValidCoordDataT<Dims, CoordT>::value) &&
474475
((IsImageAcc && IsImageAccessReadOnly) ||
475476
(IsHostImageAcc && IsImageAccessAnyRead))>>
476477
DataT read(const CoordT &Coords, const sampler &Smpl) const {
@@ -494,10 +495,10 @@ class image_accessor
494495
// accessMode == access::mode::read_write))
495496
template <
496497
typename CoordT, int Dims = Dimensions,
497-
typename = std::enable_if_t<(Dims > 0) && (detail::is_genint_v<CoordT>) &&
498-
(IsValidCoordDataT<Dims, CoordT>::value) &&
499-
((IsImageAcc && IsImageAccessWriteOnly) ||
500-
(IsHostImageAcc && IsImageAccessAnyWrite))>>
498+
typename = std::enable_if_t<
499+
(IsValidCoordDataT<Dims, CoordT, /* AllowFP = */ false>::value) &&
500+
((IsImageAcc && IsImageAccessWriteOnly) ||
501+
(IsHostImageAcc && IsImageAccessAnyWrite))>>
501502
void write(const CoordT &Coords, const DataT &Color) const {
502503
#ifdef __SYCL_DEVICE_ONLY__
503504
__invoke__ImageWrite<OCLImageTy, CoordT, DataT>(MImageObj, Coords, Color);
@@ -546,23 +547,21 @@ class __image_array_slice__ {
546547
size_t Idx)
547548
: MBaseAcc(BaseAcc), MIdx(Idx) {}
548549

549-
template <typename CoordT, int Dims = Dimensions,
550-
typename = std::enable_if_t<
551-
(Dims > 0) && (IsValidCoordDataT<Dims, CoordT>::value)>>
550+
template <
551+
typename CoordT, int Dims = Dimensions,
552+
typename = std::enable_if_t<(IsValidCoordDataT<Dims, CoordT>::value)>>
552553
DataT read(const CoordT &Coords) const {
553554
return MBaseAcc.read(getAdjustedCoords(Coords));
554555
}
555556

556557
template <typename CoordT, int Dims = Dimensions,
557-
typename = std::enable_if_t<(Dims > 0) &&
558-
IsValidCoordDataT<Dims, CoordT>::value>>
558+
typename = std::enable_if_t<IsValidCoordDataT<Dims, CoordT>::value>>
559559
DataT read(const CoordT &Coords, const sampler &Smpl) const {
560560
return MBaseAcc.read(getAdjustedCoords(Coords), Smpl);
561561
}
562562

563563
template <typename CoordT, int Dims = Dimensions,
564-
typename = std::enable_if_t<(Dims > 0) &&
565-
IsValidCoordDataT<Dims, CoordT>::value>>
564+
typename = std::enable_if_t<IsValidCoordDataT<Dims, CoordT>::value>>
566565
void write(const CoordT &Coords, const DataT &Color) const {
567566
return MBaseAcc.write(getAdjustedCoords(Coords), Color);
568567
}

sycl/include/sycl/detail/generic_type_lists.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ using scalar_vector_bfloat16_list =
7070
using bfloat16_list =
7171
tl_append<scalar_bfloat16_list, vector_bfloat16_list, marray_bfloat16_list>;
7272

73-
using half_bfloat16_list = tl_append<scalar_half_list, scalar_bfloat16_list>;
74-
7573
using scalar_float_list = type_list<float>;
7674

7775
using vector_float_list =

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,6 @@ template <typename T>
3131
inline constexpr bool is_svgenfloatf_v =
3232
is_contained_v<T, gtl::scalar_vector_float_list>;
3333

34-
template <typename T>
35-
inline constexpr bool is_half_v = is_contained_v<T, gtl::scalar_half_list>;
36-
37-
template <typename T>
38-
inline constexpr bool is_bfloat16_v =
39-
is_contained_v<T, gtl::scalar_bfloat16_list>;
40-
41-
template <typename T>
42-
inline constexpr bool is_half_or_bf16_v =
43-
is_contained_v<T, gtl::half_bfloat16_list>;
44-
4534
template <typename T>
4635
inline constexpr bool is_svgenfloath_v =
4736
is_contained_v<T, gtl::scalar_vector_half_list>;
@@ -57,9 +46,6 @@ template <typename T>
5746
inline constexpr bool is_vgenfloat_v =
5847
is_contained_v<T, gtl::vector_floating_list>;
5948

60-
template <typename T>
61-
inline constexpr bool is_genint_v = is_contained_v<T, gtl::signed_int_list>;
62-
6349
template <typename T>
6450
inline constexpr bool is_geninteger_v = is_contained_v<T, gtl::integer_list>;
6551

@@ -141,10 +127,11 @@ template <typename T> auto convertToOpenCLType(T &&x) {
141127
// sycl::half may convert to _Float16, and we would try to instantiate
142128
// vec class with _Float16 DataType, which is not expected there. As
143129
// such, leave vector<half, N> as-is.
144-
using MatchingVec = vec<std::conditional_t<is_half_v<ElemTy>, ElemTy,
145-
decltype(convertToOpenCLType(
146-
std::declval<ElemTy>()))>,
147-
no_ref::size()>;
130+
using MatchingVec =
131+
vec<std::conditional_t<std::is_same_v<ElemTy, half>, ElemTy,
132+
decltype(convertToOpenCLType(
133+
std::declval<ElemTy>()))>,
134+
no_ref::size()>;
148135
#ifdef __SYCL_DEVICE_ONLY__
149136
return sycl::bit_cast<typename MatchingVec::vector_t>(x);
150137
#else
@@ -160,11 +147,11 @@ template <typename T> auto convertToOpenCLType(T &&x) {
160147
fixed_width_unsigned<sizeof(no_ref)>>;
161148
static_assert(sizeof(OpenCLType) == sizeof(T));
162149
return static_cast<OpenCLType>(x);
163-
} else if constexpr (is_half_v<no_ref>) {
150+
} else if constexpr (std::is_same_v<no_ref, half>) {
164151
using OpenCLType = sycl::detail::half_impl::BIsRepresentationT;
165152
static_assert(sizeof(OpenCLType) == sizeof(T));
166153
return static_cast<OpenCLType>(x);
167-
} else if constexpr (is_bfloat16_v<no_ref>) {
154+
} else if constexpr (std::is_same_v<no_ref, ext::oneapi::bfloat16>) {
168155
// On host, don't interpret BF16 as uint16.
169156
#ifdef __SYCL_DEVICE_ONLY__
170157
using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT;

sycl/include/sycl/vector.hpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ struct ScalarConversionOperatorMixIn<Vec, T, N, std::enable_if_t<N == 1>> {
119119
operator T() const { return (*static_cast<const Vec *>(this))[0]; }
120120
};
121121

122+
template <typename T>
123+
inline constexpr bool is_fundamental_or_half_or_bfloat16 =
124+
std::is_fundamental_v<T> || std::is_same_v<std::remove_const_t<T>, half> ||
125+
std::is_same_v<std::remove_const_t<T>, ext::oneapi::bfloat16>;
126+
122127
} // namespace detail
123128

124129
///////////////////////// class sycl::vec /////////////////////////
@@ -288,10 +293,8 @@ class __SYCL_EBO vec
288293
// when NumElements == 1. The template prevents implicit conversion from
289294
// vec<_, 1> to DataT.
290295
template <typename Ty = DataT>
291-
typename std::enable_if_t<
292-
std::is_fundamental_v<Ty> ||
293-
detail::is_half_or_bf16_v<typename std::remove_const_t<Ty>>,
294-
vec &>
296+
typename std::enable_if_t<detail::is_fundamental_or_half_or_bfloat16<Ty>,
297+
vec &>
295298
operator=(const DataT &Rhs) {
296299
*this = vec{Rhs};
297300
return *this;
@@ -626,16 +629,14 @@ class SwizzleOp {
626629
1 != IdxNum && SwizzleOp::getNumElements() == IdxNum, T>;
627630

628631
template <typename T>
629-
using EnableIfScalarType = typename std::enable_if_t<
630-
std::is_convertible_v<DataT, T> &&
631-
(std::is_fundamental_v<T> ||
632-
detail::is_half_or_bf16_v<typename std::remove_const_t<T>>)>;
632+
using EnableIfScalarType =
633+
typename std::enable_if_t<std::is_convertible_v<DataT, T> &&
634+
detail::is_fundamental_or_half_or_bfloat16<T>>;
633635

634636
template <typename T>
635-
using EnableIfNoScalarType = typename std::enable_if_t<
636-
!std::is_convertible_v<DataT, T> ||
637-
!(std::is_fundamental_v<T> ||
638-
detail::is_half_or_bf16_v<typename std::remove_const_t<T>>)>;
637+
using EnableIfNoScalarType =
638+
typename std::enable_if_t<!std::is_convertible_v<DataT, T> ||
639+
!detail::is_fundamental_or_half_or_bfloat16<T>>;
639640

640641
template <int... Indices>
641642
using Swizzle =

sycl/source/detail/device_info.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,14 +850,96 @@ struct get_device_info_impl<
850850
matrix_type::sint32, matrix_type::sint32},
851851
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
852852
matrix_type::fp32, matrix_type::fp32},
853+
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
854+
matrix_type::fp16, matrix_type::fp32},
855+
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
856+
matrix_type::fp32, matrix_type::fp16},
857+
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
858+
matrix_type::fp16, matrix_type::fp16},
859+
{0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
860+
matrix_type::fp32, matrix_type::fp16},
861+
{0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
862+
matrix_type::fp16, matrix_type::fp16},
863+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
864+
matrix_type::fp32, matrix_type::fp32},
865+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
866+
matrix_type::fp16, matrix_type::fp32},
867+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
868+
matrix_type::fp32, matrix_type::fp16},
869+
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
870+
matrix_type::fp16, matrix_type::fp16},
871+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
872+
matrix_type::fp32, matrix_type::fp32},
873+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
874+
matrix_type::fp16, matrix_type::fp32},
875+
// fp16 row: last type field is fp16 (not bf16), matching the fp32/fp16 entry of the other fp16 shape groups.
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
876+
matrix_type::fp32, matrix_type::fp16},
877+
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
878+
matrix_type::fp16, matrix_type::fp16},
879+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
880+
matrix_type::fp32, matrix_type::fp32},
881+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
882+
matrix_type::fp16, matrix_type::fp32},
883+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
884+
matrix_type::fp32, matrix_type::fp16},
885+
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
886+
matrix_type::fp16, matrix_type::fp16},
887+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
888+
matrix_type::fp32, matrix_type::fp32},
889+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
890+
matrix_type::fp16, matrix_type::fp32},
891+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
892+
matrix_type::fp32, matrix_type::fp16},
893+
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
894+
matrix_type::fp16, matrix_type::fp16},
895+
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
896+
matrix_type::bf16, matrix_type::bf16},
897+
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
898+
matrix_type::fp32, matrix_type::bf16},
899+
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
900+
matrix_type::bf16, matrix_type::fp32},
853901
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
854902
matrix_type::fp32, matrix_type::fp32},
855903
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
856904
matrix_type::fp32, matrix_type::fp32},
905+
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
906+
matrix_type::bf16, matrix_type::fp32},
907+
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
908+
matrix_type::fp32, matrix_type::bf16},
909+
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
910+
matrix_type::bf16, matrix_type::bf16},
857911
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
858912
matrix_type::fp32, matrix_type::fp32},
913+
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
914+
matrix_type::bf16, matrix_type::fp32},
915+
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
916+
matrix_type::fp32, matrix_type::bf16},
917+
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
918+
matrix_type::bf16, matrix_type::bf16},
859919
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
860920
matrix_type::fp32, matrix_type::fp32},
921+
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
922+
matrix_type::bf16, matrix_type::fp32},
923+
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
924+
matrix_type::fp32, matrix_type::bf16},
925+
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
926+
matrix_type::bf16, matrix_type::bf16},
927+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
928+
matrix_type::fp32, matrix_type::fp32},
929+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
930+
matrix_type::bf16, matrix_type::fp32},
931+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
932+
matrix_type::fp32, matrix_type::bf16},
933+
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
934+
matrix_type::bf16, matrix_type::bf16},
935+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
936+
matrix_type::fp32, matrix_type::fp32},
937+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
938+
matrix_type::bf16, matrix_type::fp32},
939+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
940+
matrix_type::fp32, matrix_type::bf16},
941+
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
942+
matrix_type::bf16, matrix_type::bf16},
861943
{8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
862944
matrix_type::fp32, matrix_type::fp32},
863945
};

sycl/test-e2e/Matrix/common.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,17 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
6363
if constexpr (std::is_same_v<Ta, bfloat16> &&
6464
std::is_same_v<Tc, float>)
6565
acc += make_fp32(va[i]) * make_fp32(vb[i]);
66+
else if constexpr (std::is_same_v<Ta, sycl::half> &&
67+
std::is_same_v<Tc, float>)
68+
acc += (float)va[i] * (float)vb[i];
6669
else if constexpr (std::is_same_v<Ta, float> &&
6770
std::is_same_v<Tc, float> ||
6871
std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
72+
(std::is_same_v<Ta, bfloat16> ||
73+
std::is_same_v<Ta, sycl::half>) ||
6974
(std::is_same_v<Ta, double> &&
7075
std::is_same_v<Tc, double>))
7176
acc += va[i] * vb[i];
72-
else if constexpr (std::is_same_v<Ta, sycl::half> &&
73-
std::is_same_v<Tc, float>)
74-
acc += (float)va[i] * (float)vb[i];
7577
else
7678
assert(false && "Unsupported type in matrix_multiply_ref.");
7779
}

0 commit comments

Comments
 (0)