|
| 1 | +#include "TritonAMDGPUToLLVM/Passes.h" |
| 2 | +#include "llvm/IR/IRBuilder.h" |
| 3 | +#include "llvm/IR/Instructions.h" |
| 4 | +#include "llvm/IR/PatternMatch.h" |
| 5 | +#include "llvm/IR/Verifier.h" |
| 6 | +#include "llvm/Passes/PassBuilder.h" |
| 7 | + |
| 8 | +#define DEBUG_TYPE "tritonamdgpu-scalarize-packed-fops" |
| 9 | + |
| 10 | +using namespace llvm; |
| 11 | +using namespace llvm::PatternMatch; |
| 12 | + |
| 13 | +namespace { |
| 14 | + |
| 15 | +bool isMFMAorWMMA(Instruction &inst) { |
| 16 | + auto *callInst = llvm::dyn_cast<CallInst>(&inst); |
| 17 | + if (!callInst) |
| 18 | + return false; |
| 19 | + // E.g., tail call void asm sideeffect "s_waitcnt lgkmcnt(0) ", ""() |
| 20 | + if (callInst->isInlineAsm()) |
| 21 | + return false; |
| 22 | + Function *calledFunc = callInst->getCalledFunction(); |
| 23 | + if (!calledFunc->isIntrinsic()) |
| 24 | + return false; |
| 25 | + StringRef intrinName = calledFunc->getName(); |
| 26 | + if (intrinName.contains("mfma") || intrinName.contains("wmma")) |
| 27 | + return true; |
| 28 | + return false; |
| 29 | +} |
| 30 | + |
| 31 | +bool maybeReplaceVectorFOpWithScalarFOps(Instruction *inst, |
| 32 | + IRBuilder<> &builder) { |
| 33 | + Value *lhs, *rhs; |
| 34 | + if (!match(inst, m_BinOp(m_Value(lhs), m_Value(rhs)))) |
| 35 | + return false; |
| 36 | + auto *VecLhs = dyn_cast<VectorType>(lhs->getType()); |
| 37 | + if (!VecLhs) |
| 38 | + return false; |
| 39 | + assert(!VecLhs->isScalableTy() && "expected fixed-len vector"); |
| 40 | + builder.SetInsertPoint(inst); |
| 41 | + Value *newVec = llvm::UndefValue::get(VecLhs); |
| 42 | + for (int i = 0; i < VecLhs->getElementCount().getFixedValue(); ++i) { |
| 43 | + Value *newLhs = builder.CreateExtractElement(lhs, i); |
| 44 | + Value *newRhs = builder.CreateExtractElement(rhs, i); |
| 45 | + Value *res; |
| 46 | + if (inst->getOpcode() == Instruction::FMul) |
| 47 | + res = builder.CreateFMul(newLhs, newRhs); |
| 48 | + else if (inst->getOpcode() == Instruction::FAdd) |
| 49 | + res = builder.CreateFAdd(newLhs, newRhs); |
| 50 | + else if (inst->getOpcode() == Instruction::FSub) |
| 51 | + res = builder.CreateFSub(newLhs, newRhs); |
| 52 | + else |
| 53 | + llvm::report_fatal_error("only fadd, fmul, fsub supported"); |
| 54 | + newVec = builder.CreateInsertElement(newVec, res, i); |
| 55 | + } |
| 56 | + LLVM_DEBUG(dbgs() << "ScalarizePackedFOps: Replacing: " << inst << '\n'); |
| 57 | + LLVM_DEBUG(dbgs() << " With: " << newVec << '\n'); |
| 58 | + inst->replaceAllUsesWith(newVec); |
| 59 | + return true; |
| 60 | +} |
| 61 | + |
| 62 | +// This Pass scalarizes vector `fmul`s and `fadd`s in basic blocks that contain |
| 63 | +// MFMAs. The point/purpose/value of doing so is that these get codegened to
| 64 | +// "packed" ops (`v_pk_mul_f32`/`v_pk_add_f32`) and while packed ops use |
| 65 | +// separate VALUs from MFMA tensor cores (no problem there), the instructions |
| 66 | +// themselves cannot be *issued* in parallel, thus there is a performance cost |
| 67 | +// to having such packed ops "near" MFMAs. Concretely/specifically this |
| 68 | +// eliminates `v_pk_mul_f32`/`v_pk_add_f32` operations in the final asm in bbs |
| 69 | +// with MFMAs. |
| 70 | +// |
| 71 | +// Note, these "scalar" floating point ops will still get lowered to vector |
| 72 | +// instructions like `v_mul_f32_e32 v1, v163, v114` and |
| 73 | +// `v_add_u32_e32 v1, s16, v12`, just not the "packed" variants. |
| 74 | +// |
| 75 | +// Note, these vectorized `fmul`s aren't actually emitted by triton per se - |
| 76 | +// they are introduced/inserted by the VectorCombine::foldPermuteOfBinops |
| 77 | +// pattern during the `optimize_module` pipeline (hence why this LLVM pass |
| 78 | +// needs to follow that pipeline). |
| 79 | +struct ScalarizePackedFOps : FunctionPass { |
| 80 | + ScalarizePackedFOps() : FunctionPass(ID) {} |
| 81 | + |
| 82 | + bool runOnFunction(Function &F) override { |
| 83 | + IRBuilder<> builder(F.getContext()); |
| 84 | + bool changed = false; |
| 85 | + SmallVector<Instruction *> instsToErase; |
| 86 | + for (BasicBlock &BB : F) { |
| 87 | + if (!llvm::any_of(BB, isMFMAorWMMA)) |
| 88 | + continue; |
| 89 | + for (Instruction &inst : BB) { |
| 90 | + if (inst.getOpcode() != Instruction::FMul && |
| 91 | + inst.getOpcode() != Instruction::FAdd && |
| 92 | + inst.getOpcode() != Instruction::FSub) |
| 93 | + continue; |
| 94 | + if (maybeReplaceVectorFOpWithScalarFOps(&inst, builder)) { |
| 95 | + instsToErase.push_back(&inst); |
| 96 | + changed = true; |
| 97 | + } |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + if (changed) { |
| 102 | + for (Instruction *inst : instsToErase) { |
| 103 | + if (inst) |
| 104 | + inst->eraseFromParent(); |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + // We don't do anything with this but this is a virtual function override |
| 109 | + // and the signature requires it. |
| 110 | + return changed; |
| 111 | + } |
| 112 | + |
| 113 | + static char ID; |
| 114 | +}; |
| 115 | + |
| 116 | +} // end anonymous namespace |
| 117 | + |
| 118 | +char ScalarizePackedFOps::ID = 0; |
| 119 | + |
| 120 | +namespace mlir::triton::AMD { |
| 121 | +void runScalarizePackedFOpsPass(Function &F) { |
| 122 | + ScalarizePackedFOps pass; |
| 123 | + pass.runOnFunction(F); |
| 124 | + // If there are no errors, the function returns false. |
| 125 | + assert(!llvm::verifyFunction(F) && |
| 126 | + "expected function to verify successfully"); |
| 127 | +} |
| 128 | +} // namespace mlir::triton::AMD |
0 commit comments