@@ -125,6 +125,113 @@ inline constexpr bool is_fundamental_or_half_or_bfloat16 =
125125 std::is_fundamental_v<T> || std::is_same_v<std::remove_const_t <T>, half> ||
126126 std::is_same_v<std::remove_const_t <T>, ext::oneapi::bfloat16>;
127127
128+ // Proposed SYCL specification changes have sycl::vec having different ctors
129+ // available based on the number of elements. Without C++20's concepts we'll
130+ // have to use partial specialization to represent that. This is a helper to do
131+ // that. An alternative could be to have different specializations of the
132+ // `sycl::vec` itself but then we'd need to outline all the common interfaces to
133+ // re-use them.
134+ //
135+ // Note: the functional changes haven't been implemented yet, we've split
136+ // vec_base in advance as a way to make changes easier to review/verify.
137+ //
138+ // Another note: `vector_t` is going to be removed, so corresponding ctor was
139+ // kept inside `sycl::vec` to have all `vector_t` functionality in a single
140+ // place.
141+ template <typename DataT, int NumElements> class vec_base {
142+ // https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
143+ // It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
144+ static constexpr size_t AdjustedNum = (NumElements == 3 ) ? 4 : NumElements;
145+ // This represent type of underlying value. There should be only one field
146+ // in the class, so vec<float, 16> should be equal to float16 in memory.
147+ using DataType = std::array<DataT, AdjustedNum>;
148+
149+ protected:
150+ // fields
151+ // Alignment is the same as size, to a maximum size of 64. SPEC requires
152+ // "The elements of an instance of the SYCL vec class template are stored
153+ // in memory sequentially and contiguously and are aligned to the size of
154+ // the element type in bytes multiplied by the number of elements."
155+ static constexpr int alignment = (std::min)((size_t )64 , sizeof (DataType));
156+ alignas (alignment) DataType m_Data;
157+
158+ template <size_t ... Is>
159+ constexpr vec_base (const std::array<DataT, NumElements> &Arr,
160+ std::index_sequence<Is...>)
161+ : m_Data{Arr[Is]...} {}
162+
163+ template <typename CtorArgTy>
164+ static constexpr bool AllowArgTypeInVariadicCtor = []() constexpr {
165+ if constexpr (std::is_convertible_v<CtorArgTy, DataT>) {
166+ return true ;
167+ } else if constexpr (is_vec_or_swizzle_v<CtorArgTy>) {
168+ if constexpr (CtorArgTy::size () == 1 &&
169+ std::is_convertible_v<typename CtorArgTy::element_type,
170+ DataT>) {
171+ // Temporary workaround because swizzle's `operator DataT` is a
172+ // template.
173+ return true ;
174+ }
175+ return std::is_same_v<typename CtorArgTy::element_type, DataT>;
176+ } else {
177+ return false ;
178+ }
179+ }();
180+
181+ template <typename T> static constexpr int num_elements () {
182+ if constexpr (is_vec_or_swizzle_v<T>)
183+ return T::size ();
184+ else
185+ return 1 ;
186+ }
187+
188+ // Utility trait for creating an std::array from an vector argument.
189+ template <typename DataT_, typename T> class FlattenVecArg {
190+ template <std::size_t ... Is>
191+ static constexpr auto helper (const T &V, std::index_sequence<Is...>) {
192+ // FIXME: Swizzle's `operator[]` for expression trees seems to be broken
193+ // and returns values of the underlying vector of some of the operands. On
194+ // the other hand, `getValue()` gives correct results. This can be changed
195+ // to using `operator[]` once the bug is fixed.
196+ if constexpr (is_swizzle_v<T>)
197+ return std::array{static_cast <DataT_>(V.getValue (Is))...};
198+ else
199+ return std::array{static_cast <DataT_>(V[Is])...};
200+ }
201+
202+ public:
203+ constexpr auto operator ()(const T &A) const {
204+ if constexpr (is_vec_or_swizzle_v<T>) {
205+ return helper (A, std::make_index_sequence<T ::size ()>());
206+ } else {
207+ return std::array{static_cast <DataT_>(A)};
208+ }
209+ }
210+ };
211+
212+ // Alias for shortening the vec arguments to array converter.
213+ template <typename DataT_, typename ... ArgTN>
214+ using VecArgArrayCreator = ArrayCreator<DataT_, FlattenVecArg, ArgTN...>;
215+
216+ public:
217+ constexpr vec_base () = default;
218+ constexpr vec_base (const vec_base &) = default;
219+ constexpr vec_base (vec_base &&) = default;
220+ constexpr vec_base &operator =(const vec_base &) = default ;
221+ constexpr vec_base &operator =(vec_base &&) = default ;
222+
223+ explicit constexpr vec_base (const DataT &arg)
224+ : vec_base(RepeatValue<NumElements>(arg),
225+ std::make_index_sequence<NumElements>()) {}
226+
227+ template <typename ... argTN,
228+ typename = std::enable_if_t <
229+ ((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
230+ ((num_elements<argTN>() + ...)) == NumElements>>
231+ constexpr vec_base (const argTN &...args)
232+ : vec_base{VecArgArrayCreator<DataT, argTN...>::Create (args...),
233+ std::make_index_sequence<NumElements>()} {}
234+ };
128235} // namespace detail
129236
130237// /////////////////////// class sycl::vec /////////////////////////
@@ -136,7 +243,9 @@ class __SYCL_EBO vec
136243 public detail::ApplyIf<
137244 NumElements == 1 ,
138245 detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>>>,
139- public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>> {
246+ public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
247+ // Keep it last to simplify ABI layout test:
248+ public detail::vec_base<DataT, NumElements> {
140249 static_assert (std::is_same_v<DataT, std::remove_cv_t <DataT>>,
141250 " DataT must be cv-unqualified" );
142251
@@ -145,13 +254,7 @@ class __SYCL_EBO vec
145254 " or 16 are supported" );
146255 static_assert (sizeof (bool ) == sizeof (uint8_t ), " bool size is not 1 byte" );
147256
148- // https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
149- // It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
150- static constexpr size_t AdjustedNum = (NumElements == 3 ) ? 4 : NumElements;
151-
152- // This represent type of underlying value. There should be only one field
153- // in the class, so vec<float, 16> should be equal to float16 in memory.
154- using DataType = std::array<DataT, AdjustedNum>;
257+ using Base = detail::vec_base<DataT, NumElements>;
155258
156259#ifdef __SYCL_DEVICE_ONLY__
157260 using element_type_for_vector_t = typename detail::map_type<
@@ -184,48 +287,19 @@ class __SYCL_EBO vec
184287 typename vector_t_ = vector_t ,
185288 typename = typename std::enable_if_t <std::is_same_v<vector_t_, vector_t >>>
186289 constexpr vec (vector_t_ openclVector) {
187- m_Data = sycl::bit_cast<DataType >(openclVector);
290+ this -> m_Data = sycl::bit_cast<decltype ( this -> m_Data ) >(openclVector);
188291 }
189292
190293 /* @SYCL2020
191294 * Available only when: compiled for the device.
192295 * Converts this SYCL vec instance to the underlying backend-native vector
193296 * type defined by vector_t.
194297 */
195- operator vector_t () const { return sycl::bit_cast<vector_t >(m_Data); }
298+ operator vector_t () const { return sycl::bit_cast<vector_t >(this -> m_Data ); }
196299
197300private:
198301#endif // __SYCL_DEVICE_ONLY__
199302
200- // Utility trait for creating an std::array from an vector argument.
201- template <typename DataT_, typename T> class FlattenVecArg {
202- template <std::size_t ... Is>
203- static constexpr auto helper (const T &V, std::index_sequence<Is...>) {
204- // FIXME: Swizzle's `operator[]` for expression trees seems to be broken
205- // and returns values of the underlying vector of some of the operands. On
206- // the other hand, `getValue()` gives correct results. This can be changed
207- // to using `operator[]` once the bug is fixed.
208- if constexpr (detail::is_swizzle_v<T>)
209- return std::array{static_cast <DataT_>(V.getValue (Is))...};
210- else
211- return std::array{static_cast <DataT_>(V[Is])...};
212- }
213-
214- public:
215- constexpr auto operator ()(const T &A) const {
216- if constexpr (detail::is_vec_or_swizzle_v<T>) {
217- return helper (A, std::make_index_sequence<T ::size ()>());
218- } else {
219- return std::array{static_cast <DataT_>(A)};
220- }
221- }
222- };
223-
224- // Alias for shortening the vec arguments to array converter.
225- template <typename DataT_, typename ... ArgTN>
226- using VecArgArrayCreator =
227- detail::ArrayCreator<DataT_, FlattenVecArg, ArgTN...>;
228-
229303 template <int ... Indexes>
230304 using Swizzle =
231305 detail::SwizzleOp<vec, detail::GetOp<DataT>, detail::GetOp<DataT>,
@@ -236,27 +310,6 @@ class __SYCL_EBO vec
236310 detail::SwizzleOp<const vec, detail::GetOp<DataT>, detail::GetOp<DataT>,
237311 detail::GetOp, Indexes...>;
238312
239- // Shortcuts for args validation in vec(const argTN &... args) ctor.
240- template <typename CtorArgTy>
241- static constexpr bool AllowArgTypeInVariadicCtor = []() constexpr {
242- // FIXME: This logic implements the behavior of the previous implementation.
243- if constexpr (detail::is_vec_or_swizzle_v<CtorArgTy>) {
244- if constexpr (CtorArgTy::size () == 1 )
245- return std::is_convertible_v<typename CtorArgTy::element_type, DataT>;
246- else
247- return std::is_same_v<typename CtorArgTy::element_type, DataT>;
248- } else {
249- return std::is_convertible_v<CtorArgTy, DataT>;
250- }
251- }();
252-
253- template <typename T> static constexpr int num_elements () {
254- if constexpr (detail::is_vec_or_swizzle_v<T>)
255- return T::size ();
256- else
257- return 1 ;
258- }
259-
260313 // Element type for relational operator return value.
261314 using rel_t = detail::fixed_width_signed<sizeof (DataT)>;
262315
@@ -266,35 +319,13 @@ class __SYCL_EBO vec
266319 using element_type = DataT;
267320 using value_type = DataT;
268321
269- /* ***************** Constructors **************/
270- vec () = default;
271- constexpr vec (const vec &Rhs) = default;
272- constexpr vec (vec &&Rhs) = default;
273-
274- private:
275- // Implementation detail for the next public ctor.
276- template <size_t ... Is>
277- constexpr vec (const std::array<DataT, NumElements> &Arr,
278- std::index_sequence<Is...>)
279- : m_Data{Arr[Is]...} {}
280-
281- public:
282- explicit constexpr vec (const DataT &arg)
283- : vec{detail::RepeatValue<NumElements>(arg),
284- std::make_index_sequence<NumElements>()} {}
285-
286- // Constructor from values of base type or vec of base type. Checks that
287- // base types are match and that the NumElements == sum of lengths of args.
288- template <typename ... argTN,
289- typename = std::enable_if_t <
290- ((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
291- ((num_elements<argTN>() + ...)) == NumElements>>
292- constexpr vec (const argTN &...args)
293- : vec{VecArgArrayCreator<DataT, argTN...>::Create (args...),
294- std::make_index_sequence<NumElements>()} {}
322+ using Base::Base;
323+ constexpr vec (const vec &) = default;
324+ constexpr vec (vec &&) = default;
295325
296326 /* ***************** Assignment Operators **************/
297- constexpr vec &operator =(const vec &Rhs) = default ;
327+ constexpr vec &operator =(const vec &) = default ;
328+ constexpr vec &operator =(vec &&) = default ;
298329
299330 // Template required to prevent ambiguous overload with the copy assignment
300331 // when NumElements == 1. The template prevents implicit conversion from
@@ -322,7 +353,7 @@ class __SYCL_EBO vec
322353 __SYCL2020_DEPRECATED (
323354 " get_size() is deprecated, please use byte_size() instead" )
324355 static constexpr size_t get_size() { return byte_size (); }
325- static constexpr size_t byte_size () noexcept { return sizeof (m_Data ); }
356+ static constexpr size_t byte_size () noexcept { return sizeof (Base ); }
326357
327358private:
328359 // getValue should be able to operate on different underlying
@@ -339,10 +370,10 @@ class __SYCL_EBO vec
339370
340371#ifdef __SYCL_DEVICE_ONLY__
341372 if constexpr (std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>)
342- return sycl::bit_cast<RetType>(m_Data[Index]);
373+ return sycl::bit_cast<RetType>(this -> m_Data [Index]);
343374 else
344375#endif
345- return static_cast <RetType>(m_Data[Index]);
376+ return static_cast <RetType>(this -> m_Data [Index]);
346377 }
347378
348379public:
@@ -362,14 +393,14 @@ class __SYCL_EBO vec
362393 return this ;
363394 }
364395
365- const DataT &operator [](int i) const { return m_Data[i]; }
396+ const DataT &operator [](int i) const { return this -> m_Data [i]; }
366397
367- DataT &operator [](int i) { return m_Data[i]; }
398+ DataT &operator [](int i) { return this -> m_Data [i]; }
368399
369400 template <access::address_space Space, access::decorated DecorateAddress>
370401 void load (size_t Offset, multi_ptr<const DataT, Space, DecorateAddress> Ptr) {
371402 for (int I = 0 ; I < NumElements; I++) {
372- m_Data[I] = *multi_ptr<const DataT, Space, DecorateAddress>(
403+ this -> m_Data [I] = *multi_ptr<const DataT, Space, DecorateAddress>(
373404 Ptr + Offset * NumElements + I);
374405 }
375406 }
@@ -392,15 +423,15 @@ class __SYCL_EBO vec
392423 }
393424 void load (size_t Offset, const DataT *Ptr) {
394425 for (int I = 0 ; I < NumElements; ++I)
395- m_Data[I] = Ptr[Offset * NumElements + I];
426+ this -> m_Data [I] = Ptr[Offset * NumElements + I];
396427 }
397428
398429 template <access::address_space Space, access::decorated DecorateAddress>
399430 void store (size_t Offset,
400431 multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
401432 for (int I = 0 ; I < NumElements; I++) {
402433 *multi_ptr<DataT, Space, DecorateAddress>(Ptr + Offset * NumElements +
403- I) = m_Data[I];
434+ I) = this -> m_Data [I];
404435 }
405436 }
406437 template <int Dimensions, access::mode Mode,
@@ -416,18 +447,9 @@ class __SYCL_EBO vec
416447 }
417448 void store (size_t Offset, DataT *Ptr) const {
418449 for (int I = 0 ; I < NumElements; ++I)
419- Ptr[Offset * NumElements + I] = m_Data[I];
450+ Ptr[Offset * NumElements + I] = this -> m_Data [I];
420451 }
421452
422- private:
423- // fields
424- // Alignment is the same as size, to a maximum size of 64. SPEC requires
425- // "The elements of an instance of the SYCL vec class template are stored
426- // in memory sequentially and contiguously and are aligned to the size of
427- // the element type in bytes multiplied by the number of elements."
428- static constexpr int alignment = (std::min)((size_t )64 , sizeof (DataType));
429- alignas (alignment) DataType m_Data;
430-
431453 // friends
432454 template <typename T1, typename T2, typename T3, template <typename > class T4 ,
433455 int ... T5>
@@ -1272,6 +1294,7 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
12721294
12731295 // friends
12741296 template <typename T1, int T2> friend class sycl ::vec;
1297+ template <typename , int > friend class sycl ::detail::vec_base;
12751298
12761299 template <typename T1, typename T2, typename T3, template <typename > class T4 ,
12771300 int ... T5>
0 commit comments