From 087fd8212e866468fc1e5375fc02a69ab3cc28a9 Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Tue, 5 Nov 2024 23:12:22 +0000 Subject: [PATCH 1/4] [SPIRV] Use `Op[S|U]Dot` when possible for integer dot product - use the new OpSDot/OpUDot instructions when capabilities allow in SPIRVInstructionSelector.cpp - correct the capability check to inspect the input operand type rather than the return operand type in SPIRVModuleAnalysis.cpp - add test cases to demonstrate use case in idot.ll --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 30 ++++- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 18 ++- .../CodeGen/SPIRV/hlsl-intrinsics/idot.ll | 124 +++++++++++++----- 3 files changed, 128 insertions(+), 44 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 414583aea91e6..4cb6f8e0af59c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -179,7 +179,10 @@ class SPIRVInstructionSelector : public InstructionSelector { MachineInstr &I, unsigned Opcode) const; bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; + MachineInstr &I, bool Signed) const; + + bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; template bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType, @@ -1681,9 +1684,27 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool Signed) const { + assert(I.getNumOperands() == 4); + assert(I.getOperand(2).isReg()); + assert(I.getOperand(3).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + auto DotOp = Signed ? 
SPIRV::OpSDot : SPIRV::OpUDot; + return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + // Since pre-1.6 SPIRV has no integer dot implementation, // expand by piecewise multiplying and adding the results -bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, +bool SPIRVInstructionSelector::selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { assert(I.getNumOperands() == 4); @@ -2681,7 +2702,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectFloatDot(ResVReg, ResType, I); case Intrinsic::spv_udot: case Intrinsic::spv_sdot: - return selectIntegerDot(ResVReg, ResType, I); + if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || + STI.isAtLeastSPIRVVer(VersionTuple(1, 6))) + return selectIntegerDot(ResVReg, ResType, I, /*Signed=*/IID == Intrinsic::spv_sdot); + return selectIntegerDotExpansion(ResVReg, ResType, I); case Intrinsic::spv_dot4add_i8packed: if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || STI.isAtLeastSPIRVVer(VersionTuple(1, 6))) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index e8641b3a105de..4f30d7d83b7a7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1013,10 +1013,11 @@ static void AddDotProductRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::DotProduct); const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); - const MachineInstr *InstrPtr = &MI; - assert(MI.getOperand(1).isReg() && "Unexpected operand in dot"); + assert(MI.getOperand(2).isReg() && "Unexpected operand in dot"); + const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg()); + assert(InputInstr->getOperand(1).isReg() && 
"Unexpected operand in dot input"); - Register TypeReg = InstrPtr->getOperand(1).getReg(); + Register TypeReg = InputInstr->getOperand(1).getReg(); SPIRVType *TypeDef = MRI.getVRegDef(TypeReg); if (TypeDef->getOpcode() == SPIRV::OpTypeInt) { assert(TypeDef->getOperand(1).getImm() == 32); @@ -1024,10 +1025,13 @@ static void AddDotProductRequirements(const MachineInstr &MI, } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) { SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg()); assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt); - auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8 - ? SPIRV::Capability::DotProductInput4x8Bit - : SPIRV::Capability::DotProductInputAll; - Reqs.addCapability(Capability); + if (ScalarTypeDef->getOperand(1).getImm() == 8) { + assert(TypeDef->getOperand(2).getImm() == 4 + && "Dot operand of 8-bit integer type requires 4 components"); + Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit); + } else { + Reqs.addCapability(SPIRV::Capability::DotProductInputAll); + } } } diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll index 22b6ed6bdfcbc..b952cfe24a77d 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll @@ -1,8 +1,20 @@ -; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP +; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT +; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: %if 
spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %} ; Make sure dxil operation function calls for dot are generated for int/uint vectors. +; CHECK-DAG: OpCapability Int8 +; CHECK-DOT-DAG: OpCapability DotProduct +; CHECK-DOT-DAG: OpCapability DotProductInputAll +; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit +; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product" + +; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8 +; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4 ; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16 ; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2 ; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3 @@ -11,14 +23,32 @@ ; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64 ; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2 +define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]] + +; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]] + %dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b) + ret i8 %dot +} + define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = 
OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 -; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] + +; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] %dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) ret i16 %dot } @@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] -; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 -; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] -; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 -; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] + +; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 +; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] +; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 +; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] 
%[[#sum1]] %[[#elt3]] %dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) ret i32 %dot } +define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) { +entry: +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]] + %dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b) + ret i8 %dot +} + define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] -; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2 -; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2 +; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]] %dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b) ret i16 %dot } @@ -57,32 
+108,37 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) { entry: ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]] ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] -; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 -; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] -; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 -; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]] + +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2 +; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]] +; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3 +; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]] %dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b) ret i32 %dot } define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) { entry: -; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]] -; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]] -; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]] -; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0 -; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1 -; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] 
%[[#elt0]] %[[#elt1]] +; CHECK-EXP: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]] +; CHECK-EXP: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]] +; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]] +; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0 +; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1 +; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]] %dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b) ret i64 %dot } +declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>) declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>) declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>) +declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>) declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>) declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>) declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>) From ca9b5a0884fb3e9ac695ec894158981adab43aee Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Wed, 6 Nov 2024 00:03:57 +0000 Subject: [PATCH 2/4] clang format --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 18 +++++++++--------- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 7 ++++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 4cb6f8e0af59c..7fd7ad5a34711 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1695,18 +1695,17 @@ bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg, auto DotOp = Signed ? 
SPIRV::OpSDot : SPIRV::OpUDot; return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(I.getOperand(2).getReg()) - .addUse(I.getOperand(3).getReg()) - .constrainAllUses(TII, TRI, RBI); + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(2).getReg()) + .addUse(I.getOperand(3).getReg()) + .constrainAllUses(TII, TRI, RBI); } // Since pre-1.6 SPIRV has no integer dot implementation, // expand by piecewise multiplying and adding the results -bool SPIRVInstructionSelector::selectIntegerDotExpansion(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { +bool SPIRVInstructionSelector::selectIntegerDotExpansion( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { assert(I.getNumOperands() == 4); assert(I.getOperand(2).isReg()); assert(I.getOperand(3).isReg()); @@ -2704,7 +2703,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_sdot: if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || STI.isAtLeastSPIRVVer(VersionTuple(1, 6))) - return selectIntegerDot(ResVReg, ResType, I, /*Signed=*/IID == Intrinsic::spv_sdot); + return selectIntegerDot(ResVReg, ResType, I, + /*Signed=*/IID == Intrinsic::spv_sdot); return selectIntegerDotExpansion(ResVReg, ResType, I); case Intrinsic::spv_dot4add_i8packed: if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) || diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 4f30d7d83b7a7..63f2ee966030d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1015,7 +1015,8 @@ static void AddDotProductRequirements(const MachineInstr &MI, const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); assert(MI.getOperand(2).isReg() && "Unexpected operand in dot"); const MachineInstr *InputInstr = 
MRI.getVRegDef(MI.getOperand(2).getReg()); - assert(InputInstr->getOperand(1).isReg() && "Unexpected operand in dot input"); + assert(InputInstr->getOperand(1).isReg() && + "Unexpected operand in dot input"); Register TypeReg = InputInstr->getOperand(1).getReg(); SPIRVType *TypeDef = MRI.getVRegDef(TypeReg); @@ -1026,8 +1027,8 @@ static void AddDotProductRequirements(const MachineInstr &MI, SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg()); assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt); if (ScalarTypeDef->getOperand(1).getImm() == 8) { - assert(TypeDef->getOperand(2).getImm() == 4 - && "Dot operand of 8-bit integer type requires 4 components"); + assert(TypeDef->getOperand(2).getImm() == 4 && + "Dot operand of 8-bit integer type requires 4 components"); Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit); } else { Reqs.addCapability(SPIRV::Capability::DotProductInputAll); From e52cf425c98be7c511f242244f39c3e2224d6126 Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Thu, 7 Nov 2024 23:05:48 +0000 Subject: [PATCH 3/4] self review: - fix spirv version - add missing check --- llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll index b952cfe24a77d..8acad352cdc29 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll @@ -1,9 +1,9 @@ ; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP ; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT -; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT +; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s 
--check-prefixes=CHECK,CHECK-DOT,CHECK-EXT ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %} ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %} -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %} ; Make sure dxil operation function calls for dot are generated for int/uint vectors. @@ -125,8 +125,11 @@ entry: define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) { entry: -; CHECK-EXP: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]] -; CHECK-EXP: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]] +; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]] +; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]] + +; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_64]] %[[#arg0]] %[[#arg1]] + ; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]] ; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0 ; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1 From 0558a836b0d438967bb36df1424d41ef00a917a9 Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Wed, 20 Nov 2024 19:45:10 +0000 Subject: [PATCH 4/4] self-review: - add clarifying comment/code --- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 63f2ee966030d..765ff69ed3871 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1014,12 +1014,13 @@ static void AddDotProductRequirements(const MachineInstr &MI, const MachineRegisterInfo &MRI = 
MI.getMF()->getRegInfo(); assert(MI.getOperand(2).isReg() && "Unexpected operand in dot"); - const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg()); - assert(InputInstr->getOperand(1).isReg() && - "Unexpected operand in dot input"); + // We do not consider what the previous instruction is. This is just used + // to get the input register and to check the type. + const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg()); + assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input"); + Register InputReg = Input->getOperand(1).getReg(); - Register TypeReg = InputInstr->getOperand(1).getReg(); - SPIRVType *TypeDef = MRI.getVRegDef(TypeReg); + SPIRVType *TypeDef = MRI.getVRegDef(InputReg); if (TypeDef->getOpcode() == SPIRV::OpTypeInt) { assert(TypeDef->getOperand(1).getImm() == 32); Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);