Skip to content

Commit 3857d17

Browse files
committed
move dot product requirements out of switch
1 parent efea661 commit 3857d17

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,32 @@ void addOpAccessChainReqs(const MachineInstr &Instr,
10071007
}
10081008
}
10091009

1010+
static void AddDotProductRequirements(const MachineInstr &MI,
1011+
SPIRV::RequirementHandler &Reqs,
1012+
const SPIRVSubtarget &ST) {
1013+
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1014+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1015+
Reqs.addCapability(SPIRV::Capability::DotProductKHR);
1016+
1017+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1018+
const MachineInstr *InstrPtr = &MI;
1019+
assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
1020+
1021+
Register TypeReg = InstrPtr->getOperand(1).getReg();
1022+
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1023+
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1024+
assert(TypeDef->getOperand(1).getImm() == 32);
1025+
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
1026+
} else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1027+
SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1028+
assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1029+
auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
1030+
? SPIRV::Capability::DotProductInput4x8BitKHR
1031+
: SPIRV::Capability::DotProductInputAllKHR;
1032+
Reqs.addCapability(Capability);
1033+
}
1034+
}
1035+
10101036
void addInstrRequirements(const MachineInstr &MI,
10111037
SPIRV::RequirementHandler &Reqs,
10121038
const SPIRVSubtarget &ST) {
@@ -1384,31 +1410,9 @@ void addInstrRequirements(const MachineInstr &MI,
13841410
}
13851411
break;
13861412
case SPIRV::OpSDot:
1387-
case SPIRV::OpUDot: {
1388-
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1389-
Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1390-
Reqs.addCapability(SPIRV::Capability::DotProductKHR);
1391-
1392-
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1393-
const MachineInstr *InstrPtr = &MI;
1394-
assert(MI.getOperand(1).isReg() && "Unexpected operand in dot");
1395-
1396-
Register TypeReg = InstrPtr->getOperand(1).getReg();
1397-
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1398-
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1399-
assert(TypeDef->getOperand(1).getImm() == 32);
1400-
Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPackedKHR);
1401-
} else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1402-
SPIRVType *ScalarTypeDef =
1403-
MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1404-
assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1405-
auto Capability = ScalarTypeDef->getOperand(1).getImm() == 8
1406-
? SPIRV::Capability::DotProductInput4x8BitKHR
1407-
: SPIRV::Capability::DotProductInputAllKHR;
1408-
Reqs.addCapability(Capability);
1409-
}
1413+
case SPIRV::OpUDot:
1414+
AddDotProductRequirements(MI, Reqs, ST);
14101415
break;
1411-
}
14121416
default:
14131417
break;
14141418
}

0 commit comments

Comments
 (0)