Skip to content

Commit e96975e

Browse files
committed
hoist finalizeLowering into caller
1 parent 7b2ac8f commit e96975e

File tree

1 file changed

+55
-65
lines changed

1 file changed

+55
-65
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 55 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,24 +1054,26 @@ class LowerMatrixIntrinsics {
10541054
if (FusedInsts.count(Inst))
10551055
continue;
10561056

1057-
IRBuilder<> Builder(Inst);
1058-
10591057
const ShapeInfo &SI = ShapeMap.at(Inst);
10601058

10611059
Value *Op1;
10621060
Value *Op2;
1061+
MatrixTy Result;
10631062
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1064-
VisitBinaryOperator(BinOp, SI);
1063+
Result = VisitBinaryOperator(BinOp, SI);
10651064
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1066-
VisitUnaryOperator(UnOp, SI);
1067-
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1068-
VisitCallInst(CInst);
1065+
Result = VisitUnaryOperator(UnOp, SI);
1066+
else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
1067+
Result = VisitIntrinsicInst(Intr, SI);
10691068
else if (match(Inst, m_Load(m_Value(Op1))))
1070-
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
1069+
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
10711070
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1072-
VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1071+
Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
10731072
else
10741073
continue;
1074+
1075+
IRBuilder<> Builder(Inst);
1076+
finalizeLowering(Inst, Result, Builder);
10751077
Changed = true;
10761078
}
10771079

@@ -1111,27 +1113,24 @@ class LowerMatrixIntrinsics {
11111113
}
11121114

11131115
/// Replace intrinsic calls.
1114-
void VisitCallInst(CallInst *Inst) {
1116+
MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
11151117
assert(Inst->getCalledFunction() &&
11161118
Inst->getCalledFunction()->isIntrinsic());
11171119

11181120
switch (Inst->getCalledFunction()->getIntrinsicID()) {
11191121
case Intrinsic::matrix_multiply:
1120-
LowerMultiply(Inst);
1121-
break;
1122+
return LowerMultiply(Inst);
11221123
case Intrinsic::matrix_transpose:
1123-
LowerTranspose(Inst);
1124-
break;
1124+
return LowerTranspose(Inst);
11251125
case Intrinsic::matrix_column_major_load:
1126-
LowerColumnMajorLoad(Inst);
1127-
break;
1126+
return LowerColumnMajorLoad(Inst);
11281127
case Intrinsic::matrix_column_major_store:
1129-
LowerColumnMajorStore(Inst);
1130-
break;
1128+
return LowerColumnMajorStore(Inst);
11311129
default:
1132-
llvm_unreachable(
1133-
"only intrinsics supporting shape info should be seen here");
1130+
break;
11341131
}
1132+
llvm_unreachable(
1133+
"only intrinsics supporting shape info should be seen here");
11351134
}
11361135

11371136
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1197,26 +1196,24 @@ class LowerMatrixIntrinsics {
11971196
}
11981197

11991198
/// Lower a load instruction with shape information.
1200-
void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1201-
bool IsVolatile, ShapeInfo Shape) {
1199+
MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
1200+
Value *Stride, bool IsVolatile, ShapeInfo Shape) {
12021201
IRBuilder<> Builder(Inst);
1203-
finalizeLowering(Inst,
1204-
loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1205-
Shape, Builder),
1206-
Builder);
1202+
return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
1203+
Builder);
12071204
}
12081205

12091206
/// Lowers llvm.matrix.column.major.load.
12101207
///
12111208
/// The intrinsic loads a matrix from memory using a stride between columns.
1212-
void LowerColumnMajorLoad(CallInst *Inst) {
1209+
MatrixTy LowerColumnMajorLoad(CallInst *Inst) {
12131210
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
12141211
"Intrinsic only supports column-major layout!");
12151212
Value *Ptr = Inst->getArgOperand(0);
12161213
Value *Stride = Inst->getArgOperand(1);
1217-
LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1218-
cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1219-
{Inst->getArgOperand(3), Inst->getArgOperand(4)});
1214+
return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1215+
cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1216+
{Inst->getArgOperand(3), Inst->getArgOperand(4)});
12201217
}
12211218

12221219
/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1259,28 +1256,27 @@ class LowerMatrixIntrinsics {
12591256
}
12601257

12611258
/// Lower a store instruction with shape information.
1262-
void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1263-
Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1259+
MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
1260+
MaybeAlign A, Value *Stride, bool IsVolatile,
1261+
ShapeInfo Shape) {
12641262
IRBuilder<> Builder(Inst);
12651263
auto StoreVal = getMatrix(Matrix, Shape, Builder);
1266-
finalizeLowering(Inst,
1267-
storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1268-
IsVolatile, Builder),
1269-
Builder);
1264+
return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
1265+
Builder);
12701266
}
12711267

12721268
/// Lowers llvm.matrix.column.major.store.
12731269
///
12741270
/// The intrinsic store a matrix back memory using a stride between columns.
1275-
void LowerColumnMajorStore(CallInst *Inst) {
1271+
MatrixTy LowerColumnMajorStore(CallInst *Inst) {
12761272
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
12771273
"Intrinsic only supports column-major layout!");
12781274
Value *Matrix = Inst->getArgOperand(0);
12791275
Value *Ptr = Inst->getArgOperand(1);
12801276
Value *Stride = Inst->getArgOperand(2);
1281-
LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1282-
cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1283-
{Inst->getArgOperand(4), Inst->getArgOperand(5)});
1277+
return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1278+
cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1279+
{Inst->getArgOperand(4), Inst->getArgOperand(5)});
12841280
}
12851281

12861282
// Set elements I..I+NumElts-1 to Block
@@ -2045,7 +2041,7 @@ class LowerMatrixIntrinsics {
20452041
}
20462042

20472043
/// Lowers llvm.matrix.multiply.
2048-
void LowerMultiply(CallInst *MatMul) {
2044+
MatrixTy LowerMultiply(CallInst *MatMul) {
20492045
IRBuilder<> Builder(MatMul);
20502046
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
20512047
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
@@ -2067,11 +2063,11 @@ class LowerMatrixIntrinsics {
20672063

20682064
emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
20692065
getFastMathFlags(MatMul));
2070-
finalizeLowering(MatMul, Result, Builder);
2066+
return Result;
20712067
}
20722068

20732069
/// Lowers llvm.matrix.transpose.
2074-
void LowerTranspose(CallInst *Inst) {
2070+
MatrixTy LowerTranspose(CallInst *Inst) {
20752071
MatrixTy Result;
20762072
IRBuilder<> Builder(Inst);
20772073
Value *InputVal = Inst->getArgOperand(0);
@@ -2101,28 +2097,26 @@ class LowerMatrixIntrinsics {
21012097
// TODO: Improve estimate of operations needed for transposes. Currently we
21022098
// just count the insertelement/extractelement instructions, but do not
21032099
// account for later simplifications/combines.
2104-
finalizeLowering(
2105-
Inst,
2106-
Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2107-
.addNumExposedTransposes(1),
2108-
Builder);
2100+
return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2101+
.addNumExposedTransposes(1);
21092102
}
21102103

21112104
/// Lower load instructions.
2112-
void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2113-
IRBuilder<> &Builder) {
2114-
LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
2115-
Inst->isVolatile(), SI);
2105+
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
2106+
IRBuilder<> Builder(Inst);
2107+
return LowerLoad(Inst, Ptr, Inst->getAlign(),
2108+
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
21162109
}
21172110

2118-
void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2119-
Value *Ptr, IRBuilder<> &Builder) {
2120-
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2121-
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
2111+
MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2112+
Value *Ptr) {
2113+
IRBuilder<> Builder(Inst);
2114+
return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2115+
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
21222116
}
21232117

21242118
/// Lower binary operators.
2125-
void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
2119+
MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
21262120
Value *Lhs = Inst->getOperand(0);
21272121
Value *Rhs = Inst->getOperand(1);
21282122

@@ -2141,14 +2135,12 @@ class LowerMatrixIntrinsics {
21412135
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
21422136
B.getVector(I)));
21432137

2144-
finalizeLowering(Inst,
2145-
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2146-
Result.getNumVectors()),
2147-
Builder);
2138+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2139+
Result.getNumVectors());
21482140
}
21492141

21502142
/// Lower unary operators.
2151-
void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
2143+
MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
21522144
Value *Op = Inst->getOperand(0);
21532145

21542146
IRBuilder<> Builder(Inst);
@@ -2171,10 +2163,8 @@ class LowerMatrixIntrinsics {
21712163
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
21722164
Result.addVector(BuildVectorOp(M.getVector(I)));
21732165

2174-
finalizeLowering(Inst,
2175-
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2176-
Result.getNumVectors()),
2177-
Builder);
2166+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2167+
Result.getNumVectors());
21782168
}
21792169

21802170
/// Helper to linearize a matrix expression tree into a string. Currently

0 commit comments

Comments
 (0)