2222#endif
2323
2424#ifdef __SYCL_DEVICE_ONLY__
25+
26+ #if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
27+ #define JOINT_MATRIX_INTEL (T, R, C, L, S, U ) \
28+ __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U>
29+ #else
30+ #define JOINT_MATRIX_INTEL (T, R, C, L, S, U ) \
31+ __spv::__spirv_JointMatrixINTEL<T, R, C, L, S>
32+ #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION
33+
2534template <typename T, std::size_t R, std::size_t C,
2635 __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
2736 __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
2837 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
29- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
38+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
3039__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
3140 __spv::MatrixLayout Layout = L,
3241 __spv::Scope::Flag Sc = S, int MemOperand = 0 );
@@ -36,7 +45,7 @@ template <typename T, std::size_t R, std::size_t C,
3645 __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
3746 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
3847extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL (
39- T *Ptr, __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *Object,
48+ T *Ptr, JOINT_MATRIX_INTEL( T, R, C, L, S, U) *Object,
4049 std::size_t Stride, __spv::MatrixLayout Layout = L,
4150 __spv::Scope::Flag Sc = S, int MemOperand = 0);
4251
@@ -48,11 +57,11 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
4857 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
4958 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
5059 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
51- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T2, M, N, LC, S, UC> *
60+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T2, M, N, LC, S, UC) *
5261__spirv_JointMatrixMadINTEL(
53- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
54- __spv::__spirv_JointMatrixINTEL< T1, K, N, LB, S, UB> *B,
55- __spv::__spirv_JointMatrixINTEL< T2, M, N, LC, S, UC> *C,
62+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
63+ JOINT_MATRIX_INTEL( T1, K, N, LB, S, UB) *B,
64+ JOINT_MATRIX_INTEL( T2, M, N, LC, S, UC) *C,
5665 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
5766
5867template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -63,11 +72,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
6372 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
6473 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
6574 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
66- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3 , M, N, LC, S, UC> *
75+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL (T2 , M, N, LC, S, UC) *
6776__spirv_JointMatrixUUMadINTEL(
68- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
69- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
70- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
77+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
78+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
79+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
7180 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
7281
7382template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -78,11 +87,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
7887 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
7988 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
8089 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
81- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *
90+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T3, M, N, LC, S, UC) *
8291__spirv_JointMatrixUSMadINTEL(
83- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
84- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
85- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
92+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
93+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
94+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
8695 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
8796
8897template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -93,38 +102,42 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
93102 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
94103 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
95104 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
96- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *
105+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T3, M, N, LC, S, UC) *
97106__spirv_JointMatrixSUMadINTEL(
98- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
99- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
100- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
107+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
108+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
109+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
101110 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
102111
103112template <typename T, std::size_t R, std::size_t C,
104113 __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
105114 __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
106115 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
107- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
116+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
108117__spirv_CompositeConstruct(const T v);
109118
110- template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
111- __spv::MatrixLayout L,
119+ template <typename T, std::size_t R, std::size_t C,
120+ __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
121+ __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
112122 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
113123extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL (
114- __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *);
124+ JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *);
115125
116- template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
117- __spv::MatrixLayout L,
126+ template <typename T, std::size_t R, std::size_t C,
127+ __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
128+ __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
118129 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
119130extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic (
120- __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *, size_t i);
131+ JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *, size_t i);
121132
122- template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
123- __spv::MatrixLayout L,
133+ template <typename T, std::size_t R, std::size_t C,
134+ __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
135+ __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
124136 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
125- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
126- __spirv_VectorInsertDynamic (__spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *,
137+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
138+ __spirv_VectorInsertDynamic(JOINT_MATRIX_INTEL( T, R, C, L, S, U) *,
127139 T val, size_t i);
140+ #undef JOINT_MATRIX_INTEL
128141
129142#ifndef __SPIRV_BUILTIN_DECLARATIONS__
130143#error \
0 commit comments