
Commit 569af78

zwu-2025 and paultrojahnamd authored and committed
This picks up a bug fix for AMDGPU v_permlane_swap: llvm/llvm-project#144423. Without this fix, the v_permlane_swap is wrongly sunk.

Along the way we need to fix API changes:
- Add header file for the class IRBuilder
- Add missing default parameter in convertFuncOpToLLVMFuncOp
1 parent 94d7559 commit 569af78

File tree: 3 files changed (+131, -8 lines changed)

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-8957e64a20fc7f4277565c6cfe3e555c119783ce
+570885128351868c1308bb22e8ca351d318bc4a1

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 2 additions & 7 deletions

@@ -1,14 +1,9 @@
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"

-namespace mlir {
-FailureOr<LLVM::LLVMFuncOp>
-convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
-                          ConversionPatternRewriter &rewriter,
-                          const LLVMTypeConverter &converter);
-}
-
 namespace {

 using namespace mlir;
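The hunk above drops the local forward declaration of convertFuncOpToLLVMFuncOp and includes the upstream header instead, which is how the "missing default parameter" item from the commit message is resolved. As a minimal sketch of the general C++ issue, with invented names and signatures (not the actual MLIR API): once the upstream declaration gains a defaulted parameter, a stale local redeclaration becomes a competing overload.

// Hypothetical illustration only; names and signatures are invented and do
// not match the real MLIR declarations.
namespace upstream {
// The upstream header now declares an extra, defaulted parameter.
inline int convert(int funcOp, int rewriter, int converter,
                   int *extra = nullptr) {
  return funcOp + rewriter + converter + (extra ? *extra : 0);
}

// A stale local redeclaration that omits the new parameter is a *different*
// overload, not the same function.
inline int convert(int funcOp, int rewriter, int converter) {
  return convert(funcOp, rewriter, converter, nullptr);
}
} // namespace upstream

int use() {
  // With both declarations visible, a three-argument call matches both
  // overloads equally well and is rejected as ambiguous:
  //   return upstream::convert(1, 2, 3);   // error: call is ambiguous
  // With only the stale declaration visible, the call compiles but refers to
  // a symbol the library no longer defines, so it fails to link.
  return upstream::convert(1, 2, 3, nullptr); // unambiguous: one match
}

Dropping the redeclaration and including the header, as the hunk does, leaves a single declaration for both overload resolution and name mangling.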
Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
#include "TritonAMDGPUToLLVM/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"

#define DEBUG_TYPE "tritonamdgpu-scalarize-packed-fops"

using namespace llvm;
using namespace llvm::PatternMatch;

namespace {

bool isMFMAorWMMA(Instruction &inst) {
  auto *callInst = llvm::dyn_cast<CallInst>(&inst);
  if (!callInst)
    return false;
  // E.g., tail call void asm sideeffect "s_waitcnt lgkmcnt(0) ", ""()
  if (callInst->isInlineAsm())
    return false;
  // Indirect calls have no statically known callee.
  Function *calledFunc = callInst->getCalledFunction();
  if (!calledFunc || !calledFunc->isIntrinsic())
    return false;
  StringRef intrinName = calledFunc->getName();
  if (intrinName.contains("mfma") || intrinName.contains("wmma"))
    return true;
  return false;
}

bool maybeReplaceVectorFOpWithScalarFOps(Instruction *inst,
                                         IRBuilder<> &builder) {
  Value *lhs, *rhs;
  if (!match(inst, m_BinOp(m_Value(lhs), m_Value(rhs))))
    return false;
  auto *VecLhs = dyn_cast<VectorType>(lhs->getType());
  if (!VecLhs)
    return false;
  assert(!VecLhs->isScalableTy() && "expected fixed-len vector");
  builder.SetInsertPoint(inst);
  Value *newVec = llvm::UndefValue::get(VecLhs);
  for (unsigned i = 0; i < VecLhs->getElementCount().getFixedValue(); ++i) {
    Value *newLhs = builder.CreateExtractElement(lhs, i);
    Value *newRhs = builder.CreateExtractElement(rhs, i);
    Value *res;
    if (inst->getOpcode() == Instruction::FMul)
      res = builder.CreateFMul(newLhs, newRhs);
    else if (inst->getOpcode() == Instruction::FAdd)
      res = builder.CreateFAdd(newLhs, newRhs);
    else if (inst->getOpcode() == Instruction::FSub)
      res = builder.CreateFSub(newLhs, newRhs);
    else
      llvm::report_fatal_error("only fadd, fmul, fsub supported");
    newVec = builder.CreateInsertElement(newVec, res, i);
  }
  LLVM_DEBUG(dbgs() << "ScalarizePackedFOps: Replacing: " << *inst << '\n');
  LLVM_DEBUG(dbgs() << "                          With: " << *newVec << '\n');
  inst->replaceAllUsesWith(newVec);
  return true;
}

// This pass scalarizes vector `fmul`s and `fadd`s in basic blocks that contain
// MFMAs. The point of doing so is that such vector ops get codegened to
// "packed" ops (`v_pk_mul_f32`/`v_pk_add_f32`), and while packed ops use VALUs
// separate from the MFMA tensor cores (no problem there), the instructions
// themselves cannot be *issued* in parallel, so there is a performance cost to
// having such packed ops "near" MFMAs. Concretely, this eliminates
// `v_pk_mul_f32`/`v_pk_add_f32` operations in the final asm in basic blocks
// with MFMAs.
//
// Note, these "scalar" floating-point ops will still be lowered to vector
// instructions like `v_mul_f32_e32 v1, v163, v114` and
// `v_add_u32_e32 v1, s16, v12`, just not the "packed" variants.
//
// Note, these vectorized `fmul`s aren't actually emitted by Triton per se;
// they are introduced by the VectorCombine::foldPermuteOfBinops pattern during
// the `optimize_module` pipeline (hence why this LLVM pass needs to run after
// that pipeline).
struct ScalarizePackedFOps : FunctionPass {
  ScalarizePackedFOps() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override {
    IRBuilder<> builder(F.getContext());
    bool changed = false;
    SmallVector<Instruction *> instsToErase;
    for (BasicBlock &BB : F) {
      if (!llvm::any_of(BB, isMFMAorWMMA))
        continue;
      for (Instruction &inst : BB) {
        if (inst.getOpcode() != Instruction::FMul &&
            inst.getOpcode() != Instruction::FAdd &&
            inst.getOpcode() != Instruction::FSub)
          continue;
        if (maybeReplaceVectorFOpWithScalarFOps(&inst, builder)) {
          instsToErase.push_back(&inst);
          changed = true;
        }
      }
    }

    if (changed) {
      for (Instruction *inst : instsToErase) {
        if (inst)
          inst->eraseFromParent();
      }
    }

    // The return value reports whether the function was modified, as the
    // FunctionPass interface requires; callers here do not otherwise use it.
    return changed;
  }

  static char ID;
};

} // end anonymous namespace

char ScalarizePackedFOps::ID = 0;

namespace mlir::triton::AMD {
void runScalarizePackedFOpsPass(Function &F) {
  ScalarizePackedFOps pass;
  pass.runOnFunction(F);
  // verifyFunction returns false when there are no errors.
  assert(!llvm::verifyFunction(F) &&
         "expected function to verify successfully");
}
} // namespace mlir::triton::AMD
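To make the effect of the new pass concrete, here is a minimal, self-contained sketch using only the stock LLVM C++ API (the function name mul2xf32, the module name, and the standalone main are made up for illustration and are not part of the commit). It builds by hand the scalarized form that maybeReplaceVectorFOpWithScalarFOps produces for a packed <2 x float> fmul: each lane is extracted, multiplied, and reinserted, so the AMDGPU backend selects plain v_mul_f32 instructions rather than v_pk_mul_f32.

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext ctx;
  Module mod("scalarize-sketch", ctx);

  // A function taking and returning <2 x float>, standing in for code where
  // VectorCombine produced a packed fmul.
  auto *vecTy = FixedVectorType::get(Type::getFloatTy(ctx), 2);
  auto *fnTy = FunctionType::get(vecTy, {vecTy, vecTy}, /*isVarArg=*/false);
  auto *fn =
      Function::Create(fnTy, Function::ExternalLinkage, "mul2xf32", mod);
  IRBuilder<> b(BasicBlock::Create(ctx, "entry", fn));
  Value *lhs = fn->getArg(0), *rhs = fn->getArg(1);

  // Packed form (what the pass removes near MFMAs):
  //   %v = fmul <2 x float> %lhs, %rhs   ; may be selected as v_pk_mul_f32
  // Scalarized form (what the pass emits instead): extract each lane,
  // multiply, and reassemble the vector with insertelement.
  Value *acc = UndefValue::get(vecTy);
  for (uint64_t i = 0; i < 2; ++i) {
    Value *a = b.CreateExtractElement(lhs, i);
    Value *c = b.CreateExtractElement(rhs, i);
    acc = b.CreateInsertElement(acc, b.CreateFMul(a, c), i);
  }
  b.CreateRet(acc);

  if (verifyFunction(*fn, &errs()))
    return 1;
  mod.print(outs(), /*AAW=*/nullptr);
  return 0;
}

Running the sketch prints the extractelement/fmul/insertelement chain that stands in for the single packed fmul the pass removes.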
