Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b1b79aa
[Matrix] Propagate shape information through PHI instructions
jroelofs May 27, 2025
71b99d3
move formerly unsupported test to new home
jroelofs May 27, 2025
905c1e9
clang-format
jroelofs May 27, 2025
9ee44f0
add test for ConstantDataVector lowering
jroelofs May 27, 2025
169960d
move report_fatal_error outside of NDEBUG block
jroelofs May 27, 2025
b64a134
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs May 28, 2025
18951cd
fix bad merge
jroelofs May 31, 2025
f8aea05
use col major load intrinsics
jroelofs Jun 2, 2025
e56b225
add tests for phi's consuming phi's, and phi's with more than two inputs
jroelofs Jun 2, 2025
ffbc73f
handle phi's more like other ops. instcombine will clean up after us
jroelofs Jun 2, 2025
15fd60b
handle phi's with shape mismatch
jroelofs Jun 2, 2025
88bd8cb
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs Jun 2, 2025
655eb88
simplify getMatrix shim
jroelofs Jun 2, 2025
e262f76
test the other order of shape mismatch
jroelofs Jun 2, 2025
2c86c2f
clang-format
jroelofs Jun 9, 2025
86d3545
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 10, 2025
501414f
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 10, 2025
2e5b2d4
clang-format
jroelofs Jun 11, 2025
f048181
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 11, 2025
7511d17
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 12, 2025
2821467
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 12, 2025
4ba4e66
[Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch
jroelofs Jun 12, 2025
4cbc839
review feedback: parens for initializer
jroelofs Jun 16, 2025
6f8ec49
review feedback: rename to GetMatrix
jroelofs Jun 16, 2025
104c126
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 16, 2025
67ead37
drop code for splitting constants, add test for it
jroelofs Jun 16, 2025
c9b3992
split phi's in two phases
jroelofs Jun 16, 2025
da17d10
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 16, 2025
dd55682
clang-format
jroelofs Jun 16, 2025
08be3b4
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs Jun 17, 2025
5a991f7
rm constant.ll
jroelofs Jun 17, 2025
e9d0e62
test that shows reshape shuffles are inserted in the correct spot
jroelofs Jun 18, 2025
a760840
florian's suggestion is a little simpler: we already know it's a phi
jroelofs Jun 18, 2025
a62ef50
also use getInsertionPointAtDef() for non-inst phi operand reshape pl…
jroelofs Jun 18, 2025
47dd2e2
inline the lambda shim, to simplify
jroelofs Jun 18, 2025
0debc08
rm unnecessary include
jroelofs Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 129 additions & 45 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
Expand Down Expand Up @@ -288,6 +289,7 @@ static bool isUniformShape(Value *V) {
}

switch (I->getOpcode()) {
case Instruction::PHI:
case Instruction::FNeg:
return true;
default:
Expand Down Expand Up @@ -414,6 +416,33 @@ class LowerMatrixIntrinsics {
addVector(PoisonValue::get(FixedVectorType::get(
EltTy, isColumnMajor() ? NumRows : NumColumns)));
}
MatrixTy(ConstantData *Constant, const ShapeInfo &SI)
: IsColumnMajor(SI.IsColumnMajor) {
Type *EltTy = cast<VectorType>(Constant->getType())->getElementType();
Type *RowTy = VectorType::get(EltTy, ElementCount::getFixed(SI.NumRows));

for (unsigned J = 0, D = SI.getNumVectors(); J < D; ++J) {
if (auto *CDV = dyn_cast<ConstantDataVector>(Constant)) {
unsigned Width = SI.getStride();
size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8;
StringRef Data = CDV->getRawDataValues().substr(J * Width * EltSize,
Width * EltSize);
addVector(
ConstantDataVector::getRaw(Data, Width, CDV->getElementType()));
} else if (isa<PoisonValue>(Constant))
addVector(PoisonValue::get(RowTy));
else if (isa<UndefValue>(Constant))
addVector(UndefValue::get(RowTy));
else if (isa<ConstantAggregateZero>(Constant))
addVector(ConstantAggregateZero::get(RowTy));
else {
#ifndef NDEBUG
Constant->dump();
#endif
report_fatal_error("unhandled ConstantData type");
}
}
}

Value *getVector(unsigned i) const { return Vectors[i]; }
Value *getColumn(unsigned i) const {
Expand Down Expand Up @@ -618,6 +647,10 @@ class LowerMatrixIntrinsics {
MatrixVal = M.embedInVector(Builder);
}

// If it's a constant, materialize the split version of it with this shape.
if (auto *IncomingConst = dyn_cast<ConstantData>(MatrixVal))
return MatrixTy(IncomingConst, SI);

// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
Expand Down Expand Up @@ -1146,24 +1179,26 @@ class LowerMatrixIntrinsics {
Value *Op1;
Value *Op2;
MatrixTy Result;
IRBuilder<> Builder(Inst);
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
Result = VisitBinaryOperator(BinOp, SI);
Result = VisitBinaryOperator(BinOp, SI, Builder);
else if (auto *Cast = dyn_cast<CastInst>(Inst))
Result = VisitCastInstruction(Cast, SI);
Result = VisitCastInstruction(Cast, SI, Builder);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
Result = VisitUnaryOperator(UnOp, SI);
Result = VisitUnaryOperator(UnOp, SI, Builder);
else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
Result = VisitIntrinsicInst(Intr, SI);
Result = VisitIntrinsicInst(Intr, SI, Builder);
else if (auto *Select = dyn_cast<SelectInst>(Inst))
Result = VisitSelectInst(Select, SI);
Result = VisitSelectInst(Select, SI, Builder);
else if (match(Inst, m_Load(m_Value(Op1))))
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
else if (auto *PHI = dyn_cast<PHINode>(Inst))
Result = VisitPHI(PHI, SI, Builder);
else
continue;

IRBuilder<> Builder(Inst);
finalizeLowering(Inst, Result, Builder);
Changed = true;
}
Expand Down Expand Up @@ -1204,22 +1239,22 @@ class LowerMatrixIntrinsics {
}

/// Replace intrinsic calls.
MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
IRBuilder<> &Builder) {
assert(Inst->getCalledFunction() &&
Inst->getCalledFunction()->isIntrinsic());

switch (Inst->getCalledFunction()->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
return LowerMultiply(Inst);
return LowerMultiply(Inst, Builder);
case Intrinsic::matrix_transpose:
return LowerTranspose(Inst);
return LowerTranspose(Inst, Builder);
case Intrinsic::matrix_column_major_load:
return LowerColumnMajorLoad(Inst);
return LowerColumnMajorLoad(Inst, Builder);
case Intrinsic::matrix_column_major_store:
return LowerColumnMajorStore(Inst);
return LowerColumnMajorStore(Inst, Builder);
case Intrinsic::abs:
case Intrinsic::fabs: {
IRBuilder<> Builder(Inst);
MatrixTy Result;
MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));
Expand Down Expand Up @@ -1313,23 +1348,23 @@ class LowerMatrixIntrinsics {

/// Lower a load instruction with shape information.
MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
Value *Stride, bool IsVolatile, ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
Value *Stride, bool IsVolatile, ShapeInfo Shape,
IRBuilder<> &Builder) {
return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
Builder);
}

/// Lowers llvm.matrix.column.major.load.
///
/// The intrinsic loads a matrix from memory using a stride between columns.
MatrixTy LowerColumnMajorLoad(CallInst *Inst) {
MatrixTy LowerColumnMajorLoad(CallInst *Inst, IRBuilder<> &Builder) {
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
"Intrinsic only supports column-major layout!");
Value *Ptr = Inst->getArgOperand(0);
Value *Stride = Inst->getArgOperand(1);
return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
{Inst->getArgOperand(3), Inst->getArgOperand(4)});
{Inst->getArgOperand(3), Inst->getArgOperand(4)}, Builder);
}

/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
Expand Down Expand Up @@ -1374,8 +1409,7 @@ class LowerMatrixIntrinsics {
/// Lower a store instruction with shape information.
MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
MaybeAlign A, Value *Stride, bool IsVolatile,
ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
ShapeInfo Shape, IRBuilder<> &Builder) {
auto StoreVal = getMatrix(Matrix, Shape, Builder);
return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
Builder);
Expand All @@ -1384,15 +1418,16 @@ class LowerMatrixIntrinsics {
/// Lowers llvm.matrix.column.major.store.
///
/// The intrinsic store a matrix back memory using a stride between columns.
MatrixTy LowerColumnMajorStore(CallInst *Inst) {
MatrixTy LowerColumnMajorStore(CallInst *Inst, IRBuilder<> &Builder) {
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
"Intrinsic only supports column-major layout!");
Value *Matrix = Inst->getArgOperand(0);
Value *Ptr = Inst->getArgOperand(1);
Value *Stride = Inst->getArgOperand(2);
return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
{Inst->getArgOperand(4), Inst->getArgOperand(5)});
{Inst->getArgOperand(4), Inst->getArgOperand(5)},
Builder);
}

// Set elements I..I+NumElts-1 to Block
Expand Down Expand Up @@ -1459,7 +1494,8 @@ class LowerMatrixIntrinsics {
IRBuilder<> &Builder) {
auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
(void)inserted;
assert(inserted.second && "multiple matrix lowering mapping");
assert((inserted.second || isa<PHINode>(Inst)) &&
"multiple matrix lowering mapping");

ToRemove.push_back(Inst);
Value *Flattened = nullptr;
Expand Down Expand Up @@ -2167,8 +2203,7 @@ class LowerMatrixIntrinsics {
}

/// Lowers llvm.matrix.multiply.
MatrixTy LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
MatrixTy LowerMultiply(CallInst *MatMul, IRBuilder<> &Builder) {
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
Expand All @@ -2193,9 +2228,8 @@ class LowerMatrixIntrinsics {
}

/// Lowers llvm.matrix.transpose.
MatrixTy LowerTranspose(CallInst *Inst) {
MatrixTy LowerTranspose(CallInst *Inst, IRBuilder<> &Builder) {
MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
Expand Down Expand Up @@ -2228,26 +2262,80 @@ class LowerMatrixIntrinsics {
}

/// Lower load instructions.
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
IRBuilder<> Builder(Inst);
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
IRBuilder<> &Builder) {
return LowerLoad(Inst, Ptr, Inst->getAlign(),
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
Builder);
}

MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
Value *Ptr) {
IRBuilder<> Builder(Inst);
Value *Ptr, IRBuilder<> &Builder) {
return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
Builder);
}

MatrixTy VisitPHI(PHINode *Inst, const ShapeInfo &SI, IRBuilder<> &Builder) {
// Shim this->getMatrix to insert split phi's as needed.
auto getMatrix = [this, &Builder, SI](Value *MatrixVal) -> MatrixTy {
IRBuilder<>::InsertPointGuard IPG(Builder);

auto I = Inst2ColumnMatrix.find(MatrixVal);
if (I == Inst2ColumnMatrix.end()) {
if (auto *PHI = dyn_cast<PHINode>(MatrixVal)) {
auto *EltTy = cast<VectorType>(PHI->getType())->getElementType();
MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy};

Builder.SetInsertPoint(PHI);
for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
PHI->getNumIncomingValues(),
PHI->getName()));

Inst2ColumnMatrix[PHI] = PhiM;
}
}

// getMatrix() may insert some instructions for reshaping. The safe place
// to insert them is at the end of the parent block, where the register
// allocator would have inserted the copies that materialize the PHI.
if (auto *MatrixInst = dyn_cast<Instruction>(MatrixVal))
Builder.SetInsertPoint(MatrixInst->getParent()->getTerminator());

return this->getMatrix(MatrixVal, SI, Builder);
};

MatrixTy PhiM = getMatrix(Inst);

for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues();
IncomingI != IncomingE; ++IncomingI) {
Value *IncomingV = Inst->getIncomingValue(IncomingI);
BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI);

MatrixTy OpM = getMatrix(IncomingV);

for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
}
}

// finalizeLowering() may also insert instructions in some cases. The safe
// place for those is at the end of the initial block of PHIs.
auto IP = Inst->getInsertionPointAfterDef();
assert(IP.has_value() &&
"expected to find a valid insertion point after the phi");
Builder.SetInsertPoint(*IP);
return PhiM;
}

/// Lower binary operators.
MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
IRBuilder<> &Builder) {
Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);

IRBuilder<> Builder(Inst);

MatrixTy Result;
MatrixTy A = getMatrix(Lhs, SI, Builder);
MatrixTy B = getMatrix(Rhs, SI, Builder);
Expand All @@ -2265,11 +2353,10 @@ class LowerMatrixIntrinsics {
}

/// Lower unary operators.
MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
IRBuilder<> &Builder) {
Value *Op = Inst->getOperand(0);

IRBuilder<> Builder(Inst);

MatrixTy Result;
MatrixTy M = getMatrix(Op, SI, Builder);

Expand All @@ -2293,11 +2380,10 @@ class LowerMatrixIntrinsics {
}

/// Lower cast instructions.
MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
IRBuilder<> &Builder) {
Value *Op = Inst->getOperand(0);

IRBuilder<> Builder(Inst);

MatrixTy Result;
MatrixTy M = getMatrix(Op, Shape, Builder);

Expand All @@ -2315,13 +2401,11 @@ class LowerMatrixIntrinsics {
}

/// Lower selects.
MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape, IRBuilder<> &Builder) {
Value *Cond = Inst->getOperand(0);
Value *OpA = Inst->getOperand(1);
Value *OpB = Inst->getOperand(2);

IRBuilder<> Builder(Inst);

MatrixTy Result;
MatrixTy A = getMatrix(OpA, Shape, Builder);
MatrixTy B = getMatrix(OpB, Shape, Builder);
Expand Down
Loading
Loading