@@ -294,6 +294,10 @@ void IRTranslator::addMachineCFGPred(CFGEdge Edge, MachineBasicBlock *NewPred) {
294294 MachinePreds[Edge].push_back (NewPred);
295295}
296296
297+ static bool targetSupportsBF16Type (const MachineFunction *MF) {
298+ return MF->getTarget ().getTargetTriple ().isSPIRV ();
299+ }
300+
297301static bool containsBF16Type (const User &U) {
298302 // BF16 cannot currently be represented by LLT, to avoid miscompiles we
299303 // prevent any instructions using them. FIXME: This can be removed once LLT
@@ -306,7 +310,7 @@ static bool containsBF16Type(const User &U) {
306310
307311bool IRTranslator::translateBinaryOp (unsigned Opcode, const User &U,
308312 MachineIRBuilder &MIRBuilder) {
309- if (!MF-> getTarget (). getTargetTriple (). isSPIRV () && containsBF16Type (U ))
313+ if (containsBF16Type (U) && ! targetSupportsBF16Type (MF ))
310314 return false ;
311315
312316 // Get or create a virtual register for each value.
@@ -328,7 +332,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
328332
329333bool IRTranslator::translateUnaryOp (unsigned Opcode, const User &U,
330334 MachineIRBuilder &MIRBuilder) {
331- if (!MF-> getTarget (). getTargetTriple (). isSPIRV () && containsBF16Type (U ))
335+ if (containsBF16Type (U) && ! targetSupportsBF16Type (MF ))
332336 return false ;
333337
334338 Register Op0 = getOrCreateVReg (*U.getOperand (0 ));
@@ -348,7 +352,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
348352
349353bool IRTranslator::translateCompare (const User &U,
350354 MachineIRBuilder &MIRBuilder) {
351- if (!MF-> getTarget (). getTargetTriple (). isSPIRV () && containsBF16Type (U ))
355+ if (containsBF16Type (U) && ! targetSupportsBF16Type (MF ))
352356 return false ;
353357
354358 auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1573,7 @@ bool IRTranslator::translateBitCast(const User &U,
15691573
15701574bool IRTranslator::translateCast (unsigned Opcode, const User &U,
15711575 MachineIRBuilder &MIRBuilder) {
1572- if (!MF-> getTarget (). getTargetTriple (). isSPIRV () && containsBF16Type (U ))
1576+ if (containsBF16Type (U) && ! targetSupportsBF16Type (MF ))
15731577 return false ;
15741578
15751579 uint32_t Flags = 0 ;
@@ -2688,7 +2692,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
26882692
26892693bool IRTranslator::translateInlineAsm (const CallBase &CB,
26902694 MachineIRBuilder &MIRBuilder) {
2691- if (!MF-> getTarget (). getTargetTriple (). isSPIRV () && containsBF16Type (CB ))
2695+ if (containsBF16Type (CB) && ! targetSupportsBF16Type (MF ))
26922696 return false ;
26932697
26942698 const InlineAsmLowering *ALI = MF->getSubtarget ().getInlineAsmLowering ();
@@ -2779,7 +2783,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
27792783}
27802784
27812785bool IRTranslator::translateCall (const User &U, MachineIRBuilder &MIRBuilder) {
2782- if (!MF-> getTarget (). getTargetTriple (). isSPIRV () && containsBF16Type (U ))
2786+ if (containsBF16Type (U) && ! targetSupportsBF16Type (MF ))
27832787 return false ;
27842788
27852789 const CallInst &CI = cast<CallInst>(U);
0 commit comments