@@ -1007,6 +1007,32 @@ void addOpAccessChainReqs(const MachineInstr &Instr,
10071007 }
10081008}
10091009
1010+ static void AddDotProductRequirements (const MachineInstr &MI,
1011+ SPIRV::RequirementHandler &Reqs,
1012+ const SPIRVSubtarget &ST) {
1013+ if (ST.canUseExtension (SPIRV::Extension::SPV_KHR_integer_dot_product))
1014+ Reqs.addExtension (SPIRV::Extension::SPV_KHR_integer_dot_product);
1015+ Reqs.addCapability (SPIRV::Capability::DotProductKHR);
1016+
1017+ const MachineRegisterInfo &MRI = MI.getMF ()->getRegInfo ();
1018+ const MachineInstr *InstrPtr = &MI;
1019+ assert (MI.getOperand (1 ).isReg () && " Unexpected operand in dot" );
1020+
1021+ Register TypeReg = InstrPtr->getOperand (1 ).getReg ();
1022+ SPIRVType *TypeDef = MRI.getVRegDef (TypeReg);
1023+ if (TypeDef->getOpcode () == SPIRV::OpTypeInt) {
1024+ assert (TypeDef->getOperand (1 ).getImm () == 32 );
1025+ Reqs.addCapability (SPIRV::Capability::DotProductInput4x8BitPackedKHR);
1026+ } else if (TypeDef->getOpcode () == SPIRV::OpTypeVector) {
1027+ SPIRVType *ScalarTypeDef = MRI.getVRegDef (TypeDef->getOperand (1 ).getReg ());
1028+ assert (ScalarTypeDef->getOpcode () == SPIRV::OpTypeInt);
1029+ auto Capability = ScalarTypeDef->getOperand (1 ).getImm () == 8
1030+ ? SPIRV::Capability::DotProductInput4x8BitKHR
1031+ : SPIRV::Capability::DotProductInputAllKHR;
1032+ Reqs.addCapability (Capability);
1033+ }
1034+ }
1035+
10101036void addInstrRequirements (const MachineInstr &MI,
10111037 SPIRV::RequirementHandler &Reqs,
10121038 const SPIRVSubtarget &ST) {
@@ -1384,31 +1410,9 @@ void addInstrRequirements(const MachineInstr &MI,
13841410 }
13851411 break ;
13861412 case SPIRV::OpSDot:
1387- case SPIRV::OpUDot: {
1388- if (ST.canUseExtension (SPIRV::Extension::SPV_KHR_integer_dot_product))
1389- Reqs.addExtension (SPIRV::Extension::SPV_KHR_integer_dot_product);
1390- Reqs.addCapability (SPIRV::Capability::DotProductKHR);
1391-
1392- const MachineRegisterInfo &MRI = MI.getMF ()->getRegInfo ();
1393- const MachineInstr *InstrPtr = &MI;
1394- assert (MI.getOperand (1 ).isReg () && " Unexpected operand in dot" );
1395-
1396- Register TypeReg = InstrPtr->getOperand (1 ).getReg ();
1397- SPIRVType *TypeDef = MRI.getVRegDef (TypeReg);
1398- if (TypeDef->getOpcode () == SPIRV::OpTypeInt) {
1399- assert (TypeDef->getOperand (1 ).getImm () == 32 );
1400- Reqs.addCapability (SPIRV::Capability::DotProductInput4x8BitPackedKHR);
1401- } else if (TypeDef->getOpcode () == SPIRV::OpTypeVector) {
1402- SPIRVType *ScalarTypeDef =
1403- MRI.getVRegDef (TypeDef->getOperand (1 ).getReg ());
1404- assert (ScalarTypeDef->getOpcode () == SPIRV::OpTypeInt);
1405- auto Capability = ScalarTypeDef->getOperand (1 ).getImm () == 8
1406- ? SPIRV::Capability::DotProductInput4x8BitKHR
1407- : SPIRV::Capability::DotProductInputAllKHR;
1408- Reqs.addCapability (Capability);
1409- }
1413+ case SPIRV::OpUDot:
1414+ AddDotProductRequirements (MI, Reqs, ST);
14101415 break ;
1411- }
14121416 default :
14131417 break ;
14141418 }
0 commit comments