Skip to content

Commit aa0886f

Browse files
YuriPlyakhin authored and igcbot committed
Improve tf32 dpas Prototype for A/B Matrix Sources
Update the TF32 DPAS prototype to use float types for the A and B matrices. The previous short type for matrix A did not allow pre-processing of the A matrix and did not work well with 2D block loads for u32 sources.
1 parent 8504571 commit aa0886f

File tree

4 files changed

+41
-17
lines changed

4 files changed

+41
-17
lines changed

IGC/BiFModule/Implementation/IGCBiF_Intrinsics_Dpas.cl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -339,16 +339,16 @@ half8 __builtin_IB_sub_group16_fdpas_hf_hf_hf_hf_8_8 (half8 acc, short8 a,
339339

340340

341341
// tf32, rcount = 1, simd16
342-
float __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_1 (float acc, short a, int8 b) __attribute__((const));
342+
float __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_1 (float acc, float a, float8 b) __attribute__((const));
343343

344344
// tf32, rcount = 2, simd16
345-
float2 __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_2 (float2 acc, short2 a, int8 b) __attribute__((const));
345+
float2 __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_2 (float2 acc, float a, float8 b) __attribute__((const));
346346

347347
// tf32, rcount = 4, simd16
348-
float4 __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_4 (float4 acc, short4 a, int8 b) __attribute__((const));
348+
float4 __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_4 (float4 acc, float2 a, float8 b) __attribute__((const));
349349

350350
// tf32, rcount = 8, simd16
351-
float8 __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_8 (float8 acc, short8 a, int8 b) __attribute__((const));
351+
float8 __builtin_IB_sub_group16_fdpas_f_f_tf32_tf32_8_8 (float8 acc, float4 a, float8 b) __attribute__((const));
352352

353353

354354
//

IGC/BiFModule/Languages/OpenCL/IBiF_dpas.cl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -483,10 +483,10 @@ DEFN_INTEL_CVT2( f32_to_bf16_packed, int16, float16, float16, 2fto2bf_16 )
483483
#ifdef cl_intel_subgroup_matrix_multiply_accumulate_tf32
484484
// PVC_B
485485

486-
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float, float, short, int8, fdpas_f_f_tf32_tf32_8_1 )
487-
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float2, float2, short2, int8, fdpas_f_f_tf32_tf32_8_2 )
488-
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float4, float4, short4, int8, fdpas_f_f_tf32_tf32_8_4 )
489-
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float8, float8, short8, int8, fdpas_f_f_tf32_tf32_8_8 )
486+
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float, float, float, float8, fdpas_f_f_tf32_tf32_8_1 )
487+
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float2, float2, float, float8, fdpas_f_f_tf32_tf32_8_2 )
488+
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float4, float4, float2, float8, fdpas_f_f_tf32_tf32_8_4 )
489+
DEFN_INTEL_SG16_FDPAS( tf32_tf32_matrix_mad_k8_f32, float8, float8, float4, float8, fdpas_f_f_tf32_tf32_8_8 )
490490

491491
DEFN_INTEL_CVT( f32_to_tf32, int, float, ftotf32_1 )
492492
DEFN_INTEL_CVT( f32_to_tf32, int2, float2, ftotf32_2 )

IGC/BiFModule/Languages/OpenCL/PreRelease/opencl_cth_pre_release.h

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2380,11 +2380,22 @@ int16 __attribute__((overloadable)) intel_convert_f32_to_bf16_packed(float16 a,
23802380

23812381
#ifdef cl_intel_subgroup_matrix_multiply_accumulate_tf32
23822382

2383-
// A: half of tfloat32 B: tfloat32 ACC: float DST: float
2384-
float __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short a, int8 b, float acc);
2385-
float2 __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short2 a, int8 b, float2 acc);
2386-
float4 __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short4 a, int8 b, float4 acc);
2387-
float8 __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short8 a, int8 b, float8 acc);
2383+
// A: tf32, even rows in lower 8 SIMD channels, odd rows in upper 8 SIMD channels
2384+
// B: tf32
2385+
// ACC: float
2386+
// DST: float
2387+
2388+
// M = 1, K = 8, N = 16, upper 8 channels of a ignored
2389+
float __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float a, float8 b, float acc);
2390+
2391+
// M = 2, K = 8, N = 16, all channels of a are used
2392+
float2 __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float a, float8 b, float2 acc);
2393+
2394+
// M = 4, K = 8, N = 16
2395+
float4 __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float2 a, float8 b, float4 acc);
2396+
2397+
// M = 8, K = 8, N = 16
2398+
float8 __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float4 a, float8 b, float8 acc);
23882399

23892400
// Conversions
23902401
int __attribute__((overloadable)) intel_convert_f32_to_tf32(float source);

IGC/Compiler/Optimizer/OpenCLPasses/DpasFuncs/DpasFuncsResolution.cpp

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -421,16 +421,21 @@ void DpasFuncsResolution::visitCallInst(CallInst& CI)
421421
IGC_ASSERT_MESSAGE(ACC_nelts == RC, "ICE: dpas intrinsic has mismatched vector sizes of arguments!");
422422
IGC_ASSERT_MESSAGE(B_nelts == SD, "ICE: dpas intrinsic has mismatched vector sizes of arguments!");
423423
IGC_ASSERT_MESSAGE(precOk, "ICE: dpas's A and B have illegal type combination!");
424-
IGC_ASSERT_MESSAGE(B_BaseTy->isIntegerTy(32), "ICE: dpas's arg B shall have base type int32!");
425-
IGC_ASSERT_MESSAGE(RC == (IsDpasw ? 2 * A_nelts : A_nelts), "ICE: dpas's arg A has wrong element size!");
424+
IGC_ASSERT_MESSAGE(B_BaseTy->isIntegerTy(32) || (PB == PrecisionType::TF32 && B_BaseTy->isFloatTy()),
425+
"ICE: dpas's arg B shall have base type int32 or float!");
426+
IGC_ASSERT_MESSAGE((RC == (IsDpasw ? 2 * A_nelts : A_nelts) ||
427+
(PA == PrecisionType::TF32 && (RC == 2 * A_nelts))),
428+
"ICE: dpas's arg A has wrong element size!");
426429

427430
uint32_t AbitsPerDepth = 32;
428431
if (m_pCtx->platform.hasExecSize16DPAS())
429432
{
430433
AbitsPerDepth = AbitsPerDepth / 2;
431434
}
432435

433-
IGC_ASSERT_MESSAGE(A_BaseTy->isIntegerTy(AbitsPerDepth), "ICE: dpas intrinsic's A has wrong base type!");
436+
IGC_ASSERT_MESSAGE(A_BaseTy->isIntegerTy(AbitsPerDepth) ||
437+
(PA == PrecisionType::TF32 && A_BaseTy->isFloatTy()),
438+
"ICE: dpas intrinsic's A has wrong base type!");
434439
if (PA == PrecisionType::TF32)
435440
{
436441
if (!(DstTy == DSTACC_FLOAT && AccTy == DSTACC_FLOAT))
@@ -458,7 +463,15 @@ void DpasFuncsResolution::visitCallInst(CallInst& CI)
458463
Value* args[8];
459464
args[0] = CI.getArgOperand(0);
460465
args[1] = CI.getArgOperand(1);
461-
args[2] = CI.getArgOperand(2);
466+
467+
Value* B = CI.getArgOperand(2);
468+
Type* BTy = B->getType();
469+
if (FixedVectorType *BVecTy = dyn_cast<FixedVectorType>(BTy); BVecTy && BTy->getScalarType()->isFloatTy()) {
470+
B = CastInst::Create(Instruction::CastOps::BitCast, B,
471+
FixedVectorType::get(intTy, (unsigned) BVecTy->getNumElements()),
472+
B->getName() + ".cast", &CI);
473+
}
474+
args[2] = B;
462475

463476
args[3] = ConstantInt::get(intTy, PA);
464477
args[4] = ConstantInt::get(intTy, PB);

0 commit comments

Comments
 (0)