Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
MachineInstr &I, unsigned Opcode) const;

bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
MachineInstr &I, bool Signed) const;

bool selectIntegerDotExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

template <bool Signed>
bool selectDot4AddPacked(Register ResVReg, const SPIRVType *ResType,
Expand Down Expand Up @@ -1705,11 +1708,28 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}

// Since pre-1.6 SPIRV has no integer dot implementation,
// expand by piecewise multiplying and adding the results
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
MachineInstr &I,
bool Signed) const {
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
MachineBasicBlock &BB = *I.getParent();

auto DotOp = Signed ? SPIRV::OpSDot : SPIRV::OpUDot;
return BuildMI(BB, I, I.getDebugLoc(), TII.get(DotOp))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.constrainAllUses(TII, TRI, RBI);
}

// Since pre-1.6 SPIRV has no integer dot implementation,
// expand by piecewise multiplying and adding the results
bool SPIRVInstructionSelector::selectIntegerDotExpansion(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
Expand Down Expand Up @@ -2755,7 +2775,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectFloatDot(ResVReg, ResType, I);
case Intrinsic::spv_udot:
case Intrinsic::spv_sdot:
return selectIntegerDot(ResVReg, ResType, I);
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
return selectIntegerDot(ResVReg, ResType, I,
/*Signed=*/IID == Intrinsic::spv_sdot);
return selectIntegerDotExpansion(ResVReg, ResType, I);
case Intrinsic::spv_dot4add_i8packed:
if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
STI.isAtLeastSPIRVVer(VersionTuple(1, 6)))
Expand Down
24 changes: 15 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,21 +1029,27 @@ static void AddDotProductRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::DotProduct);

const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
const MachineInstr *InstrPtr = &MI;
assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");

Register TypeReg = InstrPtr->getOperand(1).getReg();
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
// We do not consider what the previous instruction is. This is just used
// to get the input register and to check the type.
const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
Register InputReg = Input->getOperand(1).getReg();

SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
assert(TypeDef->getOperand(1).getImm() == 32);
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
} else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
? SPIRV::Capability::DotProductInput4x8Bit
: SPIRV::Capability::DotProductInputAll;
Reqs.addCapability(Capability);
if (ScalarTypeDef->getOperand(1).getImm() == 8) {
assert(TypeDef->getOperand(2).getImm() == 4 &&
"Dot operand of 8-bit integer type requires 4 components");
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
} else {
Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
}
}
}

Expand Down
123 changes: 91 additions & 32 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-EXP
; RUN: llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
; RUN: llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-EXT
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-unknown -spirv-ext=+SPV_KHR_integer_dot_product %s -o - -filetype=obj | spirv-val %}

; Make sure dxil operation function calls for dot are generated for int/uint vectors.

; CHECK-DAG: OpCapability Int8
; CHECK-DOT-DAG: OpCapability DotProduct
; CHECK-DOT-DAG: OpCapability DotProductInputAll
; CHECK-DOT-DAG: OpCapability DotProductInput4x8Bit
; CHECK-EXT-DAG: OpExtension "SPV_KHR_integer_dot_product"

; CHECK-DAG: %[[#int_8:]] = OpTypeInt 8
; CHECK-DAG: %[[#vec4_int_8:]] = OpTypeVector %[[#int_8]] 4
; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
Expand All @@ -11,14 +23,32 @@
; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2

define noundef i8 @dot_int8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]

; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_8]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
%dot = call i8 @llvm.spv.sdot.v4i8(<4 x i8> %a, <4 x i8> %b)
ret i8 %dot
}

define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]

; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_16]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
%dot = call i16 @llvm.spv.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
ret i16 %dot
}
Expand All @@ -27,28 +57,49 @@ define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]

; CHECK-DOT: %[[#dot:]] = OpSDot %[[#int_32]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
%dot = call i32 @llvm.spv.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}

define noundef i8 @dot_uint8_t4(<4 x i8> noundef %a, <4 x i8> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_8]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_8]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_8]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_8]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_8]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum:]] = OpIAdd %[[#int_8]] %[[#elt0]] %[[#elt1]]
%dot = call i8 @llvm.spv.udot.v4i8(<4 x i8> %a, <4 x i8> %b)
ret i8 %dot
}

define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_16]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
%dot = call i16 @llvm.spv.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
ret i16 %dot
}
Expand All @@ -57,14 +108,17 @@ define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_32]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK-EXP: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK-EXP: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK-EXP: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK-EXP: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
%dot = call i32 @llvm.spv.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}
Expand All @@ -73,16 +127,21 @@ define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]

; CHECK-DOT: %[[#dot:]] = OpUDot %[[#int_64]] %[[#arg0]] %[[#arg1]]

; CHECK-EXP: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
; CHECK-EXP: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
; CHECK-EXP: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
; CHECK-EXP: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
%dot = call i64 @llvm.spv.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
ret i64 %dot
}

declare i8 @llvm.spv.sdot.v4i8(<4 x i8>, <4 x i8>)
declare i16 @llvm.spv.sdot.v2i16(<2 x i16>, <2 x i16>)
declare i32 @llvm.spv.sdot.v4i32(<4 x i32>, <4 x i32>)
declare i8 @llvm.spv.udot.v4i8(<4 x i8>, <4 x i8>)
declare i16 @llvm.spv.udot.v3i32(<3 x i16>, <3 x i16>)
declare i32 @llvm.spv.udot.v4i32(<4 x i32>, <4 x i32>)
declare i64 @llvm.spv.udot.v2i64(<2 x i64>, <2 x i64>)
Loading