@@ -1385,13 +1385,28 @@ void addInstrRequirements(const MachineInstr &MI,
13851385 break ;
13861386 case SPIRV::OpSDot:
13871387 case SPIRV::OpUDot: {
1388- if (ST.canUseExtension (SPIRV::Extension::SPV_KHR_integer_dot_product)) {
1388+ if (ST.canUseExtension (SPIRV::Extension::SPV_KHR_integer_dot_product))
13891389 Reqs.addExtension (SPIRV::Extension::SPV_KHR_integer_dot_product);
1390- }
13911390 Reqs.addCapability (SPIRV::Capability::DotProductKHR);
1392- Reqs.addCapability (SPIRV::Capability::DotProductInputAllKHR);
1393- Reqs.addCapability (SPIRV::Capability::DotProductInput4x8BitKHR);
1394- Reqs.addCapability (SPIRV::Capability::DotProductInput4x8BitPackedKHR);
1391+
1392+ const MachineRegisterInfo &MRI = MI.getMF ()->getRegInfo ();
1393+ const MachineInstr *InstrPtr = &MI;
1394+ assert (MI.getOperand (1 ).isReg () && " Unexpected operand in dot" );
1395+
1396+ Register TypeReg = InstrPtr->getOperand (1 ).getReg ();
1397+ SPIRVType *TypeDef = MRI.getVRegDef (TypeReg);
1398+ if (TypeDef->getOpcode () == SPIRV::OpTypeInt) {
1399+ assert (TypeDef->getOperand (1 ).getImm () == 32 );
1400+ Reqs.addCapability (SPIRV::Capability::DotProductInput4x8BitPackedKHR);
1401+ } else if (TypeDef->getOpcode () == SPIRV::OpTypeVector) {
1402+ SPIRVType *ScalarTypeDef =
1403+ MRI.getVRegDef (TypeDef->getOperand (1 ).getReg ());
1404+ assert (ScalarTypeDef->getOpcode () == SPIRV::OpTypeInt);
1405+ auto Capability = ScalarTypeDef->getOperand (1 ).getImm () == 8
1406+ ? SPIRV::Capability::DotProductInput4x8BitKHR
1407+ : SPIRV::Capability::DotProductInputAllKHR;
1408+ Reqs.addCapability (Capability);
1409+ }
13951410 break ;
13961411 }
13971412 default :
0 commit comments