// Patch overview (unified diff touching three files; everything below the
// comment preamble is the patch payload and is left unmodified).
//
// llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
//   * selectIntegerDot gains a `bool Signed` parameter. Its new body asserts
//     that the instruction has four operands with operands 2 and 3 being
//     registers, then builds a single OpSDot (Signed) or OpUDot (otherwise)
//     with the result register, the SPIR-V type id of ResType, and operands
//     2 and 3 as uses.
//   * The previous piecewise multiply-and-add lowering is kept under the new
//     name selectIntegerDotExpansion (only its first lines are visible here).
//   * selectIntrinsic: for spv_udot / spv_sdot, selectIntegerDot is used when
//     the subtarget can use SPV_KHR_integer_dot_product or is at least
//     SPIR-V 1.6; otherwise selectIntegerDotExpansion is used.
//
// llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
//   * AddDotProductRequirements now looks at operand 2 of the dot
//     instruction, takes the instruction defining that register, and reads
//     that instruction's operand 1 as the type register (previously it read
//     operand 1 of the dot instruction itself).
//   * OpTypeInt input: asserts a width of 32 and adds
//     DotProductInput4x8BitPacked.
//   * OpTypeVector input with an 8-bit integer scalar: asserts 4 components
//     and adds DotProductInput4x8Bit; any other scalar width adds
//     DotProductInputAll. The component-count check is an assert only, so it
//     has no effect when asserts are compiled out.
//
// llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
//   * Three llc configurations are exercised: spirv1.5 (prefix CHECK-EXP),
//     spirv1.6 (CHECK-DOT), and spirv1.5 with
//     -spirv-ext=+SPV_KHR_integer_dot_product (CHECK-DOT and CHECK-EXT); each
//     is also piped through spirv-val when spirv-tools is available.
//   * New <4 x i8> signed and unsigned cases are added, together with
//     capability / extension checks.
//
// NOTE(review): this text appears to be whitespace-collapsed (original line
//   breaks and indentation are gone), so it presumably cannot be applied
//   as-is -- regenerate from the commit before use.
// NOTE(review): `template bool selectDot4AddPacked(` looks like it lost a
//   template parameter list during extraction -- confirm against the
//   upstream header.
// NOTE(review): in the test, @llvm.spv.sdot.v3i16 is called with <2 x i16>
//   operands while the declaration is spelled ...sdot.v2i16, and
//   @llvm.spv.udot.v3i32 is declared with <3 x i16> operands while the call
//   is spelled ...udot.v3i16. These are context lines, so they are
//   presumably pre-existing; verify whether the suffixes are intentional.
// NOTE(review): the CHECK-EXP lines for the two <4 x i8> cases only match
//   the extracts of elements 0 and 1 and a single OpIAdd, whereas the
//   <4 x i32> cases check all four elements -- confirm the shorter check is
//   intended.
// NOTE(review): the test's header comment mentions "dxil operation function
//   calls" although every RUN line targets a spirv triple -- wording may be
//   stale.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 6b23b0d6b6f78..d564384c906dc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -180,7 +180,10 @@ class SPIRVInstructionSelector : public InstructionSelector { MachineInstr &I, unsigned Opcode) const; bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; + MachineInstr &I, bool Signed) const; + + bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; template bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType, @@ -1705,11 +1708,28 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } -// Since pre-1.6 SPIRV has no integer dot implementation, -// expand by piecewise multiplying and adding the results bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const { + MachineInstr &I, + bool Signed) const { + assert(I.getNumOperands() == 4); + assert(I.getOperand(2).isReg()); + assert(I.getOperand(3).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + auto DotOp = Signed ? 
SPIRV::OpSDot : SPIRV::OpUDot; + return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + +// Since pre-1.6 SPIRV has no integer dot implementation, +// expand by piecewise multiplying and adding the results +bool SPIRVInstructionSelector::selectIntegerDotExpansion( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { assert(I.getNumOperands() == 4); assert(I.getOperand(2).isReg()); assert(I.getOperand(3).isReg()); @@ -2755,7 +2775,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectFloatDot(ResVReg, ResType, I); case Intrinsic::spv_udot: case Intrinsic::spv_sdot: - return selectIntegerDot(ResVReg, ResType, I); + if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || + STI.isAtLeastSPIRVVer(VersionTuple(1, 6))) + return selectIntegerDot(ResVReg, ResType, I, + /*Signed=*/IID == Intrinsic::spv_sdot); + return selectIntegerDotExpansion(ResVReg, ResType, I); case Intrinsic::spv_dot4add_i8packed: if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || STI.isAtLeastSPIRVVer(VersionTuple(1, 6))) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index ea64f3d1c722c..d9f928eb90640 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1029,21 +1029,27 @@ static void AddDotProductRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::DotProduct); const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); - const MachineInstr *InstrPtr = &MI; - assert(MI.getOperand(1).isReg() && "Unexpected operand in dot"); - - Register TypeReg = InstrPtr->getOperand(1).getReg(); - SPIRVType *TypeDef = MRI.getVRegDef(TypeReg); + assert(MI.getOperand(2).isReg() && "Unexpected operand in dot"); + // We do not 
consider what the previous instruction is. This is just used + // to get the input register and to check the type. + const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg()); + assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input"); + Register InputReg = Input->getOperand(1).getReg(); + + SPIRVType *TypeDef = MRI.getVRegDef(InputReg); if (TypeDef->getOpcode() == SPIRV::OpTypeInt) { assert(TypeDef->getOperand(1).getImm() == 32); Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked); } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) { SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg()); assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt); - auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8 - ? SPIRV::Capability::DotProductInput4x8Bit - : SPIRV::Capability::DotProductInputAll; - Reqs.addCapability(Capability); + if (ScalarTypeDef->getOperand(1).getImm() == 8) { + assert(TypeDef->getOperand(2).getImm() == 4 && + "Dot operand of 8-bit integer type requires 4 components"); + Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit); + } else { + Reqs.addCapability(SPIRV::Capability::DotProductInputAll); + } } } diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll index 22b6ed6bdfcbc..8acad352cdc29 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll @@ -1,8 +1,20 @@ -; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP +; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT +; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s 
--check-prefixes=CHECK,CHECK-DOT,CHECK-EXT +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %} ; Make sure dxil operation function calls for dot are generated for int/uint vectors. +; CHECK-DAG: OpCapability Int8 +; CHECK-DOT-DAG: OpCapability DotProduct +; CHECK-DOT-DAG: OpCapability DotProductInputAll +; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit +; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product" + +; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8 +; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4 ; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16 ; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2 ; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3 @@ -11,14 +23,32 @@ ; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64 ; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2 +define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]] + +; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]] + %dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b) + ret i8 %dot +} + define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]] -; CHECK: 
%[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 -; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] + +; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] %dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) ret i16 %dot } @@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] -; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 -; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] -; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 -; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] + +; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 +; CHECK-EXP: %[[#sum1:]] = OpIAdd 
%[[#int_32]] %[[#sum0]] %[[#elt2]] +; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 +; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] %dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) ret i32 %dot } +define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]] + %dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b) + ret i8 %dot +} + define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] -; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2 -; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2 +; 
CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]] %dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b) ret i16 %dot } @@ -57,14 +108,17 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] -; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 -; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] -; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 -; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 +; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] +; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 +; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] %dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b) ret i32 %dot } @@ -73,16 +127,21 @@ define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]] -; CHECK: 
%[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_64]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]] %dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b) ret i64 %dot } +declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>) declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>) declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>) +declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>) declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>) declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>) declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)