Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
MachineInstr &I, unsigned Opcode) const;

bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
MachineInstr &I, bool Signed) const;

bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

template <bool Signed>
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
Expand Down Expand Up @@ -1681,11 +1684,28 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}

// Since pre-1.6 SPIRV has no integer dot implementation,
// expand by piecewise multiplying and adding the results
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
MachineInstr &I,
bool Signed) const {
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
MachineBasicBlock &BB = *I.getParent();

auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.constrainAllUses(TII, TRI, RBI);
}

// Since pre-1.6 SPIRV has no integer dot implementation,
// expand by piecewise multiplying and adding the results
bool SPIRVInstructionSelector::selectIntegerDotExpansion(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
Expand Down Expand Up @@ -2681,7 +2701,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectFloatDot(ResVReg, ResType, I);
case Intrinsic::spv_udot:
case Intrinsic::spv_sdot:
return selectIntegerDot(ResVReg, ResType, I);
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
return selectIntegerDot(ResVReg, ResType, I,
/*Signed=*/IID == Intrinsic::spv_sdot);
return selectIntegerDotExpansion(ResVReg, ResType, I);
case Intrinsic::spv_dot4add_i8packed:
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
Expand Down
19 changes: 12 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,21 +1013,26 @@ static void AddDotProductRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::DotProduct);

const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
const MachineInstr *InstrPtr = &MI;
assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
assert(InputInstr->getOperand(1).isReg() &&
"Unexpected operand in dot input");
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This represents a complete change of which instruction we are vetting here. Can you explain why the change was needed and maybe why it worked before, but we don't need to verify that it does?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

The previous implementation is incorrect as the return register of the dot op is always a scalar integer type. Instead we need to check the type of the input operands as this is what the capability is correctly checked on. I think the terminology of the InputInstr is confusing then, as we are just checking the input register type and not really switching which instruction we are checking. I will fix that up to help with code readability.

The reason this wasn't caught was because the dot4add intrinsics only has i32 inputs and so the other code-paths were not tested.

Unfortunately, I checked-in untested code, so a better way forward would have been to only add the DotProductInput4x8BitPacked capability in the previous dot4add commits and then added the other ones here.


Register TypeReg = InstrPtr->getOperand(1).getReg();
Register TypeReg = InputInstr->getOperand(1).getReg();
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
assert(TypeDef->getOperand(1).getImm() == 32);
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
} else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
? SPIRV::Capability::DotProductInput4x8Bit
: SPIRV::Capability::DotProductInputAll;
Reqs.addCapability(Capability);
if (ScalarTypeDef->getOperand(1).getImm() == 8) {
assert(TypeDef->getOperand(2).getImm() == 4 &&
"Dot operand of 8-bit integer type requires 4 components");
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
} else {
Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
}
}
}

Expand Down
123 changes: 91 additions & 32 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}

; Make sure dxil operation function calls for dot are generated for int/uint vectors.

; CHECK-DAG: OpCapability Int8
; CHECK-DOT-DAG: OpCapability DotProduct
; CHECK-DOT-DAG: OpCapability DotProductInputAll
; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit
; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product"

; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8
; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4
; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
Expand All @@ -11,14 +23,32 @@
; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2

define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]

; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
%dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b)
ret i8 %dot
}

define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]

; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
%dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
ret i16 %dot
}
Expand All @@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]

; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
%dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}

define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
%dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b)
ret i8 %dot
}

define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
%dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
ret i16 %dot
}
Expand All @@ -57,14 +108,17 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
%dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}
Expand All @@ -73,16 +127,21 @@ define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_64]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
%dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
ret i64 %dot
}

declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>)
declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>)
declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>)
declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>)
declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>)
declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>)
declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)
Loading