Skip to content

Commit 55ab633

Browse files
committed
Add predicate for establishing target's support for BF16.
1 parent 965e1fa commit 55ab633

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ void IRTranslator::addMachineCFGPred(CFGEdge Edge, MachineBasicBlock *NewPred) {
294294
MachinePreds[Edge].push_back(NewPred);
295295
}
296296

297+
static bool targetSupportsBF16Type(const MachineFunction *MF) {
298+
return MF->getTarget().getTargetTriple().isSPIRV();
299+
}
300+
297301
static bool containsBF16Type(const User &U) {
298302
// BF16 cannot currently be represented by LLT, to avoid miscompiles we
299303
// prevent any instructions using them. FIXME: This can be removed once LLT
@@ -306,7 +310,7 @@ static bool containsBF16Type(const User &U) {
306310

307311
bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
308312
MachineIRBuilder &MIRBuilder) {
309-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
313+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
310314
return false;
311315

312316
// Get or create a virtual register for each value.
@@ -328,7 +332,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
328332

329333
bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U,
330334
MachineIRBuilder &MIRBuilder) {
331-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
335+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
332336
return false;
333337

334338
Register Op0 = getOrCreateVReg(*U.getOperand(0));
@@ -348,7 +352,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
348352

349353
bool IRTranslator::translateCompare(const User &U,
350354
MachineIRBuilder &MIRBuilder) {
351-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
355+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
352356
return false;
353357

354358
auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1573,7 @@ bool IRTranslator::translateBitCast(const User &U,
15691573

15701574
bool IRTranslator::translateCast(unsigned Opcode, const User &U,
15711575
MachineIRBuilder &MIRBuilder) {
1572-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
1576+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
15731577
return false;
15741578

15751579
uint32_t Flags = 0;
@@ -2688,7 +2692,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
26882692

26892693
bool IRTranslator::translateInlineAsm(const CallBase &CB,
26902694
MachineIRBuilder &MIRBuilder) {
2691-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(CB))
2695+
if (containsBF16Type(CB) && !targetSupportsBF16Type(MF))
26922696
return false;
26932697

26942698
const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering();
@@ -2779,7 +2783,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
27792783
}
27802784

27812785
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
2782-
if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
2786+
if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
27832787
return false;
27842788

27852789
const CallInst &CI = cast<CallInst>(U);

0 commit comments

Comments
 (0)