Skip to content

Commit fc234c1

Browse files
mateuszchudykigcbot
authored andcommitted
Changes in code.
1 parent a636237 commit fc234c1

File tree

3 files changed

+100
-100
lines changed

3 files changed

+100
-100
lines changed

IGC/BiFModule/Implementation/IGCBiF_Intrinsics_Dpas.cl

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ DPAS_DEPTH_8( __builtin_IB_idpas, int, short8, int8, 2, 4, 1 )
7272
DPAS_DEPTH_8( __builtin_IB_idpas, int, int8, int4, 4, 2, 1 )
7373
DPAS_DEPTH_8( __builtin_IB_idpas, int, short8, int4, 2, 2, 1 )
7474

75-
float __builtin_IB_fdpas_bf_bf_8_1 (float acc, uint8 a, uint8 b) __attribute__((const));
76-
float __builtin_IB_fdpas_hf_hf_8_1 (float acc, uint8 a, uint8 b) __attribute__((const));
75+
float __builtin_IB_fdpas_bf_bf_8_1 (float acc, int8 a, int8 b) __attribute__((const));
76+
float __builtin_IB_fdpas_hf_hf_8_1 (float acc, int8 a, int8 b) __attribute__((const));
7777

7878
//
7979
// Sub group version of dpas. (suffix: <pa-bits>_<pb-bits>_<depth>_<rcount>)
@@ -135,16 +135,16 @@ float4 __builtin_IB_sub_group_fdpas_8_4 (float4 acc, int4 a, int8 b) __attribute
135135
float8 __builtin_IB_sub_group_fdpas_8_8 (float8 acc, int8 a, int8 b) __attribute__((const)); // deprecated
136136

137137
// bfloat16
138-
float __builtin_IB_sub_group_fdpas_bf_bf_8_1 (float acc, uint a, uint8 b) __attribute__((const));
139-
float2 __builtin_IB_sub_group_fdpas_bf_bf_8_2 (float2 acc, uint2 a, uint8 b) __attribute__((const));
140-
float4 __builtin_IB_sub_group_fdpas_bf_bf_8_4 (float4 acc, uint4 a, uint8 b) __attribute__((const));
141-
float8 __builtin_IB_sub_group_fdpas_bf_bf_8_8 (float8 acc, uint8 a, uint8 b) __attribute__((const));
138+
float __builtin_IB_sub_group_fdpas_bf_bf_8_1 (float acc, int a, int8 b) __attribute__((const));
139+
float2 __builtin_IB_sub_group_fdpas_bf_bf_8_2 (float2 acc, int2 a, int8 b) __attribute__((const));
140+
float4 __builtin_IB_sub_group_fdpas_bf_bf_8_4 (float4 acc, int4 a, int8 b) __attribute__((const));
141+
float8 __builtin_IB_sub_group_fdpas_bf_bf_8_8 (float8 acc, int8 a, int8 b) __attribute__((const));
142142

143143
// half
144-
float __builtin_IB_sub_group_fdpas_hf_hf_8_1 (float acc, uint a, uint8 b) __attribute__((const));
145-
float2 __builtin_IB_sub_group_fdpas_hf_hf_8_2 (float2 acc, uint2 a, uint8 b) __attribute__((const));
146-
float4 __builtin_IB_sub_group_fdpas_hf_hf_8_4 (float4 acc, uint4 a, uint8 b) __attribute__((const));
147-
float8 __builtin_IB_sub_group_fdpas_hf_hf_8_8 (float8 acc, uint8 a, uint8 b) __attribute__((const));
144+
float __builtin_IB_sub_group_fdpas_hf_hf_8_1 (float acc, int a, int8 b) __attribute__((const));
145+
float2 __builtin_IB_sub_group_fdpas_hf_hf_8_2 (float2 acc, int2 a, int8 b) __attribute__((const));
146+
float4 __builtin_IB_sub_group_fdpas_hf_hf_8_4 (float4 acc, int4 a, int8 b) __attribute__((const));
147+
float8 __builtin_IB_sub_group_fdpas_hf_hf_8_8 (float8 acc, int8 a, int8 b) __attribute__((const));
148148

149149
//
150150
// dpasw: 'a' size is the half of the dpas version.
@@ -189,14 +189,14 @@ DPAS_DEPTH_8( __builtin_IB_sub_group_idpasw, int8, int4, int4, 4, 2, 8 )
189189
DPAS_DEPTH_8( __builtin_IB_sub_group_idpasw, int8, short4, int4, 2, 2, 8 )
190190

191191
// bfloat16
192-
float2 __builtin_IB_sub_group_fdpasw_bf_bf_8_2 (float2 acc, uint a, uint8 b) __attribute__((const));
193-
float4 __builtin_IB_sub_group_fdpasw_bf_bf_8_4 (float4 acc, uint2 a, uint8 b) __attribute__((const));
194-
float8 __builtin_IB_sub_group_fdpasw_bf_bf_8_8 (float8 acc, uint4 a, uint8 b) __attribute__((const));
192+
float2 __builtin_IB_sub_group_fdpasw_bf_bf_8_2 (float2 acc, int a, int8 b) __attribute__((const));
193+
float4 __builtin_IB_sub_group_fdpasw_bf_bf_8_4 (float4 acc, int2 a, int8 b) __attribute__((const));
194+
float8 __builtin_IB_sub_group_fdpasw_bf_bf_8_8 (float8 acc, int4 a, int8 b) __attribute__((const));
195195

196196
// half
197-
float2 __builtin_IB_sub_group_fdpasw_hf_hf_8_2 (float2 acc, uint a, uint8 b) __attribute__((const));
198-
float4 __builtin_IB_sub_group_fdpasw_hf_hf_8_4 (float4 acc, uint2 a, uint8 b) __attribute__((const));
199-
float8 __builtin_IB_sub_group_fdpasw_hf_hf_8_8 (float8 acc, uint4 a, uint8 b) __attribute__((const));
197+
float2 __builtin_IB_sub_group_fdpasw_hf_hf_8_2 (float2 acc, int a, int8 b) __attribute__((const));
198+
float4 __builtin_IB_sub_group_fdpasw_hf_hf_8_4 (float4 acc, int2 a, int8 b) __attribute__((const));
199+
float8 __builtin_IB_sub_group_fdpasw_hf_hf_8_8 (float8 acc, int4 a, int8 b) __attribute__((const));
200200

201201
// Names of SIMD16 dpas builtin functions are in the form:
202202
// __builtin_IB_sub_group16_idpas_<a's precision>_<b's precision>_<depth>_<repeatCount>
@@ -289,52 +289,52 @@ DPAS_DEPTH_8( __builtin_IB_sub_group16_idpas, int8, char8, int4, 2, 2, 8 )
289289
//
290290

291291
// bfloat16, rcount = 1, simd16
292-
float __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1 (float acc, ushort a, uint8 b) __attribute__((const));
293-
ushort __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_1 (float acc, ushort a, uint8 b) __attribute__((const));
294-
float __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_1 (ushort acc, ushort a, uint8 b) __attribute__((const));
295-
ushort __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_1 (ushort acc, ushort a, uint8 b) __attribute__((const));
292+
float __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1 (float acc, short a, int8 b) __attribute__((const));
293+
short __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_1 (float acc, short a, int8 b) __attribute__((const));
294+
float __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_1 (short acc, short a, int8 b) __attribute__((const));
295+
short __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_1 (short acc, short a, int8 b) __attribute__((const));
296296

297297
// bfloat16, rcount = 2, simd16
298-
float2 __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_2 (float2 acc, ushort2 a, uint8 b) __attribute__((const));
299-
ushort2 __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_2 (float2 acc, ushort2 a, uint8 b) __attribute__((const));
300-
float2 __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_2 (ushort2 acc, ushort2 a, uint8 b) __attribute__((const));
301-
ushort2 __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_2 (ushort2 acc, ushort2 a, uint8 b) __attribute__((const));
298+
float2 __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_2 (float2 acc, short2 a, int8 b) __attribute__((const));
299+
short2 __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_2 (float2 acc, short2 a, int8 b) __attribute__((const));
300+
float2 __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_2 (short2 acc, short2 a, int8 b) __attribute__((const));
301+
short2 __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_2 (short2 acc, short2 a, int8 b) __attribute__((const));
302302

303303
// bfloat16, rcount = 4, simd16
304-
float4 __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_4 (float4 acc, ushort4 a, uint8 b) __attribute__((const));
305-
ushort4 __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_4 (float4 acc, ushort4 a, uint8 b) __attribute__((const));
306-
float4 __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_4 (ushort4 acc, ushort4 a, uint8 b) __attribute__((const));
307-
ushort4 __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_4 (ushort4 acc, ushort4 a, uint8 b) __attribute__((const));
304+
float4 __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_4 (float4 acc, short4 a, int8 b) __attribute__((const));
305+
short4 __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_4 (float4 acc, short4 a, int8 b) __attribute__((const));
306+
float4 __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_4 (short4 acc, short4 a, int8 b) __attribute__((const));
307+
short4 __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_4 (short4 acc, short4 a, int8 b) __attribute__((const));
308308

309309
// bfloat16, rcount = 8, simd16
310-
float8 __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8 (float8 acc, ushort8 a, uint8 b) __attribute__((const));
311-
ushort8 __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_8 (float8 acc, ushort8 a, uint8 b) __attribute__((const));
312-
float8 __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_8 (ushort8 acc, ushort8 a, uint8 b) __attribute__((const));
313-
ushort8 __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_8 (ushort8 acc, ushort8 a, uint8 b) __attribute__((const));
310+
float8 __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8 (float8 acc, short8 a, int8 b) __attribute__((const));
311+
short8 __builtin_IB_sub_group16_fdpas_bf_f_bf_bf_8_8 (float8 acc, short8 a, int8 b) __attribute__((const));
312+
float8 __builtin_IB_sub_group16_fdpas_f_bf_bf_bf_8_8 (short8 acc, short8 a, int8 b) __attribute__((const));
313+
short8 __builtin_IB_sub_group16_fdpas_bf_bf_bf_bf_8_8 (short8 acc, short8 a, int8 b) __attribute__((const));
314314

315315
// half, rcount = 1, simd16
316-
float __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_1 (float acc, ushort a, uint8 b) __attribute__((const));
317-
half __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_1 (float acc, ushort a, uint8 b) __attribute__((const));
318-
float __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_1 (half acc, ushort a, uint8 b) __attribute__((const));
319-
half __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_1 (half acc, ushort a, uint8 b) __attribute__((const));
316+
float __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_1 (float acc, short a, int8 b) __attribute__((const));
317+
half __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_1 (float acc, short a, int8 b) __attribute__((const));
318+
float __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_1 (half acc, short a, int8 b) __attribute__((const));
319+
half __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_1 (half acc, short a, int8 b) __attribute__((const));
320320

321321
// half, rcount = 2, simd16
322-
float2 __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_2 (float2 acc, ushort2 a, uint8 b) __attribute__((const));
323-
half2 __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_2 (float2 acc, ushort2 a, uint8 b) __attribute__((const));
324-
float2 __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_2 (half2 acc, ushort2 a, uint8 b) __attribute__((const));
325-
half2 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_2 (half2 acc, ushort2 a, uint8 b) __attribute__((const));
322+
float2 __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_2 (float2 acc, short2 a, int8 b) __attribute__((const));
323+
half2 __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_2 (float2 acc, short2 a, int8 b) __attribute__((const));
324+
float2 __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_2 (half2 acc, short2 a, int8 b) __attribute__((const));
325+
half2 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_2 (half2 acc, short2 a, int8 b) __attribute__((const));
326326

327327
// half, rcount = 4, simd16
328-
float4 __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_4 (float4 acc, ushort4 a, uint8 b) __attribute__((const));
329-
half4 __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_4 (float4 acc, ushort4 a, uint8 b) __attribute__((const));
330-
float4 __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_4 (half4 acc, ushort4 a, uint8 b) __attribute__((const));
331-
half4 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_4 (half4 acc, ushort4 a, uint8 b) __attribute__((const));
328+
float4 __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_4 (float4 acc, short4 a, int8 b) __attribute__((const));
329+
half4 __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_4 (float4 acc, short4 a, int8 b) __attribute__((const));
330+
float4 __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_4 (half4 acc, short4 a, int8 b) __attribute__((const));
331+
half4 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_4 (half4 acc, short4 a, int8 b) __attribute__((const));
332332

333333
// half, rcount = 8, simd16
334-
float8 __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8 (float8 acc, ushort8 a, uint8 b) __attribute__((const));
335-
half8 __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_8 (float8 acc, ushort8 a, uint8 b) __attribute__((const));
336-
float8 __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_8 (half8 acc, ushort8 a, uint8 b) __attribute__((const));
337-
half8 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_8 (half8 acc, ushort8 a, uint8 b) __attribute__((const));
334+
float8 __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8 (float8 acc, short8 a, int8 b) __attribute__((const));
335+
half8 __builtin_IB_sub_group16_fdpas_hf_f_hf_hf_8_8 (float8 acc, short8 a, int8 b) __attribute__((const));
336+
float8 __builtin_IB_sub_group16_fdpas_f_hf_hf_hf_8_8 (half8 acc, short8 a, int8 b) __attribute__((const));
337+
half8 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_8 (half8 acc, short8 a, int8 b) __attribute__((const));
338338

339339

340340
// tf32, rcount = 1, simd16

IGC/BiFModule/Languages/OpenCL/IBiF_dpas.cl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -238,16 +238,16 @@ DEFN_INTEL_SG_IDPAS( u2_u2_matrix_mad_k64, int8, ushort8, uint4, idpas_u2_u2_8_8
238238

239239

240240
// bfloat16: both a and b are 2 bfloat16.
241-
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float, uint, uint8, fdpas_bf_bf_8_1 )
242-
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float2, uint2, uint8, fdpas_bf_bf_8_2 )
243-
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float4, uint4, uint8, fdpas_bf_bf_8_4 )
244-
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float8, uint8, uint8, fdpas_bf_bf_8_8 )
241+
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float, int, int8, fdpas_bf_bf_8_1 )
242+
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float2, int2, int8, fdpas_bf_bf_8_2 )
243+
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float4, int4, int8, fdpas_bf_bf_8_4 )
244+
DEFN_INTEL_SG_FDPAS( bf16_bf16_matrix_mad_k16, float8, int8, int8, fdpas_bf_bf_8_8 )
245245

246246
// half: both a and b are 2 half.
247-
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float, uint, uint8, fdpas_hf_hf_8_1 )
248-
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float2, uint2, uint8, fdpas_hf_hf_8_2 )
249-
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float4, uint4, uint8, fdpas_hf_hf_8_4 )
250-
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float8, uint8, uint8, fdpas_hf_hf_8_8 )
247+
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float, int, int8, fdpas_hf_hf_8_1 )
248+
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float2, int2, int8, fdpas_hf_hf_8_2 )
249+
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float4, int4, int8, fdpas_hf_hf_8_4 )
250+
DEFN_INTEL_SG_FDPAS( f16_f16_matrix_mad_k16, float8, int8, int8, fdpas_hf_hf_8_8 )
251251

252252

253253
//// PVC : simd16 ////
@@ -418,10 +418,10 @@ DEFN_INTEL_SG16_IDPAS( u2_u2_matrix_mad_k64, int8, uchar8, uint4, idpas_u2_u2_8_
418418

419419

420420
// bfloat16: both a and b are 2 bfloat16.
421-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float, float, ushort, uint8, fdpas_f_f_bf_bf_8_1 )
422-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float2, float2, ushort2, uint8, fdpas_f_f_bf_bf_8_2 )
423-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float4, float4, ushort4, uint8, fdpas_f_f_bf_bf_8_4 )
424-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float8, float8, ushort8, uint8, fdpas_f_f_bf_bf_8_8 )
421+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float, float, short, int8, fdpas_f_f_bf_bf_8_1 )
422+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float2, float2, short2, int8, fdpas_f_f_bf_bf_8_2 )
423+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float4, float4, short4, int8, fdpas_f_f_bf_bf_8_4 )
424+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float8, float8, short8, int8, fdpas_f_f_bf_bf_8_8 )
425425
//DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short, float, short, int8, fdpas_bf_f_bf_bf_8_1 )
426426
//DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short2, float2, short2, int8, fdpas_bf_f_bf_bf_8_2 )
427427
//DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short4, float4, short4, int8, fdpas_bf_f_bf_bf_8_4 )
@@ -430,16 +430,16 @@ DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float8, float8, ushort8, uint8
430430
//DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float2, short2, short2, int8, fdpas_f_bf_bf_bf_8_2 )
431431
//DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float4, short4, short4, int8, fdpas_f_bf_bf_bf_8_4 )
432432
//DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, float8, short8, short8, int8, fdpas_f_bf_bf_bf_8_8 )
433-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, ushort, ushort, ushort, uint8, fdpas_bf_bf_bf_bf_8_1 )
434-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, ushort2, ushort2, ushort2, uint8, fdpas_bf_bf_bf_bf_8_2 )
435-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, ushort4, ushort4, ushort4, uint8, fdpas_bf_bf_bf_bf_8_4 )
436-
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, ushort8, ushort8, ushort8, uint8, fdpas_bf_bf_bf_bf_8_8 )
433+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short, short, short, int8, fdpas_bf_bf_bf_bf_8_1 )
434+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short2, short2, short2, int8, fdpas_bf_bf_bf_bf_8_2 )
435+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short4, short4, short4, int8, fdpas_bf_bf_bf_bf_8_4 )
436+
DEFN_INTEL_SG16_FDPAS( bf16_bf16_matrix_mad_k16, short8, short8, short8, int8, fdpas_bf_bf_bf_bf_8_8 )
437437

438438
// half: both a and b are 2 half.
439-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float, float, ushort, uint8, fdpas_f_f_hf_hf_8_1 )
440-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float2, float2, ushort2, uint8, fdpas_f_f_hf_hf_8_2 )
441-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float4, float4, ushort4, uint8, fdpas_f_f_hf_hf_8_4 )
442-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float8, float8, ushort8, uint8, fdpas_f_f_hf_hf_8_8 )
439+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float, float, short, int8, fdpas_f_f_hf_hf_8_1 )
440+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float2, float2, short2, int8, fdpas_f_f_hf_hf_8_2 )
441+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float4, float4, short4, int8, fdpas_f_f_hf_hf_8_4 )
442+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float8, float8, short8, int8, fdpas_f_f_hf_hf_8_8 )
443443

444444
#ifdef cl_khr_fp16
445445

@@ -451,10 +451,10 @@ DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float8, float8, ushort8, uint8,
451451
//DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float2, half2, short2, int8, fdpas_f_hf_hf_hf_8_2 )
452452
//DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float4, half4, short4, int8, fdpas_f_hf_hf_hf_8_4 )
453453
//DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, float8, half8, short8, int8, fdpas_f_hf_hf_hf_8_8 )
454-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half, half, ushort, uint8, fdpas_hf_hf_hf_hf_8_1 )
455-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half2, half2, ushort2, uint8, fdpas_hf_hf_hf_hf_8_2 )
456-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half4, half4, ushort4, uint8, fdpas_hf_hf_hf_hf_8_4 )
457-
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half8, half8, ushort8, uint8, fdpas_hf_hf_hf_hf_8_8 )
454+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half, half, short, int8, fdpas_hf_hf_hf_hf_8_1 )
455+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half2, half2, short2, int8, fdpas_hf_hf_hf_hf_8_2 )
456+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half4, half4, short4, int8, fdpas_hf_hf_hf_hf_8_4 )
457+
DEFN_INTEL_SG16_FDPAS( f16_f16_matrix_mad_k16, half8, half8, short8, int8, fdpas_hf_hf_hf_hf_8_8 )
458458

459459
#endif // cl_khr_fp16
460460

@@ -660,13 +660,13 @@ DEFN_INTEL_SG_IDPAS( u2_u2_split_matrix_mad_k64, int8, ushort4, uint4, idpasw_u2
660660

661661

662662
// bfloat16: both a and b are 2 bfloat16.
663-
DEFN_INTEL_SG_FDPAS( bf16_bf16_split_matrix_mad_k16, float2, uint, uint8, fdpasw_bf_bf_8_2 )
664-
DEFN_INTEL_SG_FDPAS( bf16_bf16_split_matrix_mad_k16, float4, uint2, uint8, fdpasw_bf_bf_8_4 )
665-
DEFN_INTEL_SG_FDPAS( bf16_bf16_split_matrix_mad_k16, float8, uint4, uint8, fdpasw_bf_bf_8_8 )
663+
DEFN_INTEL_SG_FDPAS( bf16_bf16_split_matrix_mad_k16, float2, int, int8, fdpasw_bf_bf_8_2 )
664+
DEFN_INTEL_SG_FDPAS( bf16_bf16_split_matrix_mad_k16, float4, int2, int8, fdpasw_bf_bf_8_4 )
665+
DEFN_INTEL_SG_FDPAS( bf16_bf16_split_matrix_mad_k16, float8, int4, int8, fdpasw_bf_bf_8_8 )
666666

667667
// half: both a and b are 2 half.
668-
DEFN_INTEL_SG_FDPAS( f16_f16_split_matrix_mad_k16, float2, uint, uint8, fdpasw_hf_hf_8_2 )
669-
DEFN_INTEL_SG_FDPAS( f16_f16_split_matrix_mad_k16, float4, uint2, uint8, fdpasw_hf_hf_8_4 )
670-
DEFN_INTEL_SG_FDPAS( f16_f16_split_matrix_mad_k16, float8, uint4, uint8, fdpasw_hf_hf_8_8 )
668+
DEFN_INTEL_SG_FDPAS( f16_f16_split_matrix_mad_k16, float2, int, int8, fdpasw_hf_hf_8_2 )
669+
DEFN_INTEL_SG_FDPAS( f16_f16_split_matrix_mad_k16, float4, int2, int8, fdpasw_hf_hf_8_4 )
670+
DEFN_INTEL_SG_FDPAS( f16_f16_split_matrix_mad_k16, float8, int4, int8, fdpasw_hf_hf_8_8 )
671671

672672
#endif // cl_intel_subgroup_split_matrix_multiply_accumulate

0 commit comments

Comments
 (0)