@inbelic inbelic commented Nov 6, 2024

- use the new OpSDot/OpUDot instructions when capabilities allow in SPIRVInstructionSelector.cpp
- fix the capability check in SPIRVModuleAnalysis.cpp so that it inspects the type of the input operand rather than the type of the return operand

- add test cases to demonstrate use case in idot.ll

Resolves #114632
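
The selection rule described above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the code from this PR: `Target`, `DotLowering`, `chooseDotLowering`, and `expandedDot` are hypothetical names that stand in for the subtarget queries and the two lowering paths, and the fallback is shown on 32-bit values only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the subtarget queries; not LLVM's real types.
struct Target {
  unsigned Major;
  unsigned Minor;
  bool HasIntegerDotProductExt; // SPV_KHR_integer_dot_product enabled
};

enum class DotLowering { NativeSDot, NativeUDot, Expansion };

// Use the native instruction when the extension is enabled or the target is
// SPIR-V 1.6 or later; otherwise fall back to the multiply-and-add expansion.
inline DotLowering chooseDotLowering(const Target &T, bool Signed) {
  bool AtLeast16 = T.Major > 1 || (T.Major == 1 && T.Minor >= 6);
  if (T.HasIntegerDotProductExt || AtLeast16)
    return Signed ? DotLowering::NativeSDot : DotLowering::NativeUDot;
  return DotLowering::Expansion;
}

// What the expansion computes: multiply component-wise, then sum the results.
inline std::int32_t expandedDot(const std::vector<std::int32_t> &A,
                                const std::vector<std::int32_t> &B) {
  assert(A.size() == B.size() && "dot operands must have equal length");
  std::int32_t Sum = 0;
  for (std::size_t I = 0; I < A.size(); ++I)
    Sum += A[I] * B[I];
  return Sum;
}
```

With these definitions, a SPIR-V 1.5 target without the extension would take the expansion path, while either a 1.6 target or an enabled extension would select the native signed or unsigned instruction.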

@inbelic inbelic marked this pull request as ready for review November 7, 2024 22:57
llvmbot commented Nov 7, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Finn Plummer (inbelic)


Full diff: https://github.com/llvm/llvm-project/pull/115095.diff

3 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+29-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+12-7)
  • (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll (+90-34)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7aa5f4f2b1a8f1..f892832057b033 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -162,7 +162,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
                            MachineInstr &I, unsigned Opcode) const;
 
   bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
-                        MachineInstr &I) const;
+                        MachineInstr &I, bool Signed) const;
+
+  bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType,
+                                 MachineInstr &I) const;
 
   template <bool Signed>
   bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
@@ -1640,11 +1643,28 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
-// Since pre-1.6 SPIRV has no integer dot implementation,
-// expand by piecewise multiplying and adding the results
 bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
                                                 const SPIRVType *ResType,
-                                                MachineInstr &I) const {
+                                                MachineInstr &I,
+                                                bool Signed) const {
+  assert(I.getNumOperands() == 4);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(I.getOperand(2).getReg())
+      .addUse(I.getOperand(3).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Since pre-1.6 SPIRV has no integer dot implementation,
+// expand by piecewise multiplying and adding the results
+bool SPIRVInstructionSelector::selectIntegerDotExpansion(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
   assert(I.getNumOperands() == 4);
   assert(I.getOperand(2).isReg());
   assert(I.getOperand(3).isReg());
@@ -2640,7 +2660,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectFloatDot(ResVReg, ResType, I);
   case Intrinsic::spv_udot:
   case Intrinsic::spv_sdot:
-    return selectIntegerDot(ResVReg, ResType, I);
+    if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
+        STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
+      return selectIntegerDot(ResVReg, ResType, I,
+                              /*Signed=*/IID == Intrinsic::spv_sdot);
+    return selectIntegerDotExpansion(ResVReg, ResType, I);
   case Intrinsic::spv_dot4add_i8packed:
     if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
         STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index e8641b3a105dec..63f2ee966030d3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1013,10 +1013,12 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   Reqs.addCapability(SPIRV::Capability::DotProduct);
 
   const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
-  const MachineInstr *InstrPtr = &MI;
-  assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
+  assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
+  const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
+  assert(InputInstr->getOperand(1).isReg() &&
+         "Unexpected operand in dot input");
 
-  Register TypeReg = InstrPtr->getOperand(1).getReg();
+  Register TypeReg = InputInstr->getOperand(1).getReg();
   SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
   if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
     assert(TypeDef->getOperand(1).getImm() == 32);
@@ -1024,10 +1026,13 @@ static void AddDotProductRequirements(const MachineInstr &MI,
   } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
     SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
     assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
-    auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
-                          ? SPIRV::Capability::DotProductInput4x8Bit
-                          : SPIRV::Capability::DotProductInputAll;
-    Reqs.addCapability(Capability);
+    if (ScalarTypeDef->getOperand(1).getImm() == 8) {
+      assert(TypeDef->getOperand(2).getImm() == 4 &&
+             "Dot operand of 8-bit integer type requires 4 components");
+      Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
+    } else {
+      Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
+    }
   }
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
index 22b6ed6bdfcbc5..b952cfe24a77db 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
@@ -1,8 +1,20 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
+; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
+; CHECK-DAG: OpCapability Int8
+; CHECK-DOT-DAG: OpCapability DotProduct
+; CHECK-DOT-DAG: OpCapability DotProductInputAll
+; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit
+; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product"
+
+; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8
+; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4
 ; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
 ; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
 ; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
@@ -11,14 +23,32 @@
 ; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
 ; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2
 
+define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]
+
+; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
+  %dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b)
+  ret i8 %dot
+}
+
 define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
-; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+
+; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
   %dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
   ret i16 %dot
 }
@@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
-; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
-; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
-; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
-; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+
+; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
   %dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
   ret i32 %dot
 }
 
+define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
+entry:
+; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
+; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
+  %dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b)
+  ret i8 %dot
+}
+
 define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
-; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
-; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
+; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
   %dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
   ret i16 %dot
 }
@@ -57,32 +108,37 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
 entry:
 ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
 ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
-; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
-; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
-; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
-; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
+
+; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]]
+
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
+; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
+; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
+; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
   %dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
   ret i32 %dot
 }
 
 define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
 entry:
-; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
-; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
-; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
-; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
-; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
-; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
+; CHECK-EXP: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK-EXP: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
+; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
+; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
+; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
+; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
   %dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
   ret i64 %dot
 }
 
+declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>)
 declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>)
 declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>)
+declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>)
 declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>)
 declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>)
 declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)

- use the new OpSDot/OpUDot instructions when capabilities allow in SPIRVInstructionSelector.cpp
- fix the capability check in SPIRVModuleAnalysis.cpp so that it inspects the type of the input operand rather than the type of the return operand

- add test cases to demonstrate use case in idot.ll
- fix spirv version
- add missing check
@inbelic inbelic force-pushed the inbelic/dot-product branch from e404436 to e52cf42 Compare November 7, 2024 23:19
@pow2clk pow2clk left a comment

Looks good overall. Just a couple clarifying questions and one NFC suggestion.

assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
assert(InputInstr->getOperand(1).isReg() &&
"Unexpected operand in dot input");

This represents a complete change of which instruction we are vetting here. Can you explain why the change was needed and maybe why it worked before, but we don't need to verify that it does?

Contributor Author

Sure.

The previous implementation was incorrect because the return register of the dot op is always a scalar integer type. Instead, we need to check the type of the input operands, since that is what determines which capability is required. I think the name InputInstr is confusing, as we are only checking the type of the input register and not really switching which instruction we are checking. I will fix that up to improve code readability.

The reason this wasn't caught is that the dot4add intrinsics only have i32 inputs, so the other code paths were not tested.

Unfortunately, I checked in untested code, so a better way forward would have been to add only the DotProductInput4x8BitPacked capability in the previous dot4add commits and then add the other ones here.
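
The rule described in this reply, choosing the capability from the type of the dot inputs rather than from the scalar result, can be sketched on its own. This is a simplified illustration, not the code in SPIRVModuleAnalysis.cpp: `OperandType`, `DotCapability`, and `inputCapability` are hypothetical names, and mapping a scalar 32-bit input to DotProductInput4x8BitPacked is an assumption drawn from the reply above, since that line is not shown in the diff.

```cpp
#include <cassert>

// Hypothetical, simplified description of a dot operand's type.
struct OperandType {
  bool IsVector;
  unsigned BitWidth;      // width of the scalar, or of each vector component
  unsigned NumComponents; // 1 for scalars
};

enum class DotCapability {
  DotProductInput4x8BitPacked,
  DotProductInput4x8Bit,
  DotProductInputAll
};

// Select the capability from the INPUT operand type. The result type of an
// integer dot is always a scalar, so it cannot distinguish these cases.
inline DotCapability inputCapability(const OperandType &Ty) {
  if (!Ty.IsVector) {
    // Assumption: a scalar input is a 32-bit integer holding four packed
    // 8-bit values, as used by the dot4add intrinsics.
    assert(Ty.BitWidth == 32 && "packed dot input must be a 32-bit integer");
    return DotCapability::DotProductInput4x8BitPacked;
  }
  if (Ty.BitWidth == 8) {
    assert(Ty.NumComponents == 4 &&
           "Dot operand of 8-bit integer type requires 4 components");
    return DotCapability::DotProductInput4x8Bit;
  }
  return DotCapability::DotProductInputAll;
}
```

Under this sketch, a `<4 x i8>` input would request DotProductInput4x8Bit, while a `<2 x i16>` or `<4 x i32>` input would request DotProductInputAll.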

@s-perron s-perron self-requested a review November 20, 2024 18:40
- add clarifying comment/code
@pow2clk pow2clk left a comment

Looks good. Thanks for the responses.

@inbelic inbelic merged commit dcd69dd into llvm:main Nov 21, 2024
9 checks passed
@inbelic inbelic deleted the inbelic/dot-product branch November 25, 2024 21:13
Successfully merging this pull request may close these issues.

Update lowering of spv_[u|s]dot intrinsics to use SPIRV dot product instructions
