Skip to content

Commit 2b5ebd4

Browse files
committed
[SPIRV] Use Op[S|U]Dot when possible for integer dot product
- use the new OpSDot/OpUDot instructions when capabilites allow in SPIRVInstructionSelector.cpp - correct functionality of capability check onto input operand and not return operand type in SPIRVModuleAnalysis.cpp - add test cases to demonstrate use case in idot.ll
1 parent 3cdac06 commit 2b5ebd4

File tree

3 files changed

+128
-44
lines changed

3 files changed

+128
-44
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
162162
MachineInstr &I, unsigned Opcode) const;
163163

164164
bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
165-
MachineInstr &I) const;
165+
MachineInstr &I, bool Signed) const;
166+
167+
bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType,
168+
MachineInstr &I) const;
166169

167170
template <bool Signed>
168171
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
@@ -1640,9 +1643,27 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
16401643
.constrainAllUses(TII, TRI, RBI);
16411644
}
16421645

1646+
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
1647+
const SPIRVType *ResType,
1648+
MachineInstr &I,
1649+
bool Signed) const {
1650+
assert(I.getNumOperands() == 4);
1651+
assert(I.getOperand(2).isReg());
1652+
assert(I.getOperand(3).isReg());
1653+
MachineBasicBlock &BB = *I.getParent();
1654+
1655+
auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
1656+
return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
1657+
.addDef(ResVReg)
1658+
.addUse(GR.getSPIRVTypeID(ResType))
1659+
.addUse(I.getOperand(2).getReg())
1660+
.addUse(I.getOperand(3).getReg())
1661+
.constrainAllUses(TII, TRI, RBI);
1662+
}
1663+
16431664
// Since pre-1.6 SPIRV has no integer dot implementation,
16441665
// expand by piecewise multiplying and adding the results
1645-
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
1666+
bool SPIRVInstructionSelector::selectIntegerDotExpansion(Register ResVReg,
16461667
const SPIRVType *ResType,
16471668
MachineInstr &I) const {
16481669
assert(I.getNumOperands() == 4);
@@ -2640,7 +2661,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
26402661
return selectFloatDot(ResVReg, ResType, I);
26412662
case Intrinsic::spv_udot:
26422663
case Intrinsic::spv_sdot:
2643-
return selectIntegerDot(ResVReg, ResType, I);
2664+
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
2665+
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
2666+
return selectIntegerDot(ResVReg, ResType, I, /*Signed=*/IID == Intrinsic::spv_sdot);
2667+
return selectIntegerDotExpansion(ResVReg, ResType, I);
26442668
case Intrinsic::spv_dot4add_i8packed:
26452669
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
26462670
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,21 +1013,25 @@ static void AddDotProductRequirements(const MachineInstr &MI,
10131013
Reqs.addCapability(SPIRV::Capability::DotProduct);
10141014

10151015
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1016-
const MachineInstr *InstrPtr = &MI;
1017-
assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
1016+
assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1017+
const MachineInstr *InputInstr = MRI.getVRegDef(MI.getOperand(2).getReg());
1018+
assert(InputInstr->getOperand(1).isReg() && "Unexpected operand in dot input");
10181019

1019-
Register TypeReg = InstrPtr->getOperand(1).getReg();
1020+
Register TypeReg = InputInstr->getOperand(1).getReg();
10201021
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
10211022
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
10221023
assert(TypeDef->getOperand(1).getImm() == 32);
10231024
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
10241025
} else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
10251026
SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
10261027
assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1027-
auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
1028-
? SPIRV::Capability::DotProductInput4x8Bit
1029-
: SPIRV::Capability::DotProductInputAll;
1030-
Reqs.addCapability(Capability);
1028+
if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1029+
assert(TypeDef->getOperand(2).getImm() == 4
1030+
&& "Dot operand of 8-bit integer type requires 4 components");
1031+
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1032+
} else {
1033+
Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1034+
}
10311035
}
10321036
}
10331037

Lines changed: 90 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
1-
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2-
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
1+
; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
2+
; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
3+
; RUN: llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
4+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}
37

48
; Make sure dxil operation function calls for dot are generated for int/uint vectors.
59

10+
; CHECK-DAG: OpCapability Int8
11+
; CHECK-DOT-DAG: OpCapability DotProduct
12+
; CHECK-DOT-DAG: OpCapability DotProductInputAll
13+
; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit
14+
; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product"
15+
16+
; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8
17+
; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4
618
; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
719
; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
820
; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
@@ -11,14 +23,32 @@
1123
; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
1224
; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2
1325

26+
define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
27+
entry:
28+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
29+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]
30+
31+
; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]]
32+
33+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
34+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
35+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
36+
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
37+
%dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b)
38+
ret i8 %dot
39+
}
40+
1441
define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
1542
entry:
1643
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
1744
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
18-
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
19-
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
20-
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
21-
; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
45+
46+
; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]]
47+
48+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
49+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
50+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
51+
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
2252
%dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
2353
ret i16 %dot
2454
}
@@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
2757
entry:
2858
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
2959
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
30-
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
31-
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
32-
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
33-
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
34-
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
35-
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
36-
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
37-
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
60+
61+
; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]]
62+
63+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
64+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
65+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
66+
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
67+
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
68+
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
69+
; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
70+
; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
3871
%dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
3972
ret i32 %dot
4073
}
4174

75+
define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
76+
entry:
77+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
78+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]
79+
80+
; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]]
81+
82+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
83+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
84+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
85+
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
86+
%dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b)
87+
ret i8 %dot
88+
}
89+
4290
define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
4391
entry:
4492
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
4593
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
46-
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
47-
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
48-
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
49-
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
50-
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
51-
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
94+
95+
; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]]
96+
97+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
98+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
99+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
100+
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
101+
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
102+
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
52103
%dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
53104
ret i16 %dot
54105
}
@@ -57,32 +108,37 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
57108
entry:
58109
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
59110
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
60-
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
61-
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
62-
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
63-
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
64-
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
65-
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
66-
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
67-
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
111+
112+
; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]]
113+
114+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
115+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
116+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
117+
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
118+
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
119+
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
120+
; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
121+
; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
68122
%dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
69123
ret i32 %dot
70124
}
71125

72126
define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
73127
entry:
74-
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
75-
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
76-
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
77-
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
78-
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
79-
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
128+
; CHECK-EXP: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
129+
; CHECK-EXP: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
130+
; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
131+
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
132+
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
133+
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
80134
%dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
81135
ret i64 %dot
82136
}
83137

138+
declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>)
84139
declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>)
85140
declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>)
141+
declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>)
86142
declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>)
87143
declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>)
88144
declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)

0 commit comments

Comments
 (0)