Skip to content

Commit aff9068

Browse files
committed
[Matrix] Lower vector reductions using shape info
1 parent 84ff8f4 commit aff9068

File tree

2 files changed

+191
-37
lines changed

2 files changed

+191
-37
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/Analysis/ValueTracking.h"
3333
#include "llvm/Analysis/VectorUtils.h"
3434
#include "llvm/IR/CFG.h"
35+
#include "llvm/IR/Constants.h"
3536
#include "llvm/IR/DataLayout.h"
3637
#include "llvm/IR/DebugInfoMetadata.h"
3738
#include "llvm/IR/DerivedTypes.h"
@@ -41,6 +42,7 @@
4142
#include "llvm/IR/Instructions.h"
4243
#include "llvm/IR/IntrinsicInst.h"
4344
#include "llvm/IR/MatrixBuilder.h"
45+
#include "llvm/IR/Operator.h"
4446
#include "llvm/IR/PatternMatch.h"
4547
#include "llvm/Support/Alignment.h"
4648
#include "llvm/Support/CommandLine.h"
@@ -325,6 +327,25 @@ computeShapeInfoForInst(Instruction *I,
325327
return OpShape->second;
326328
}
327329

330+
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
331+
switch (II->getIntrinsicID()) {
332+
case Intrinsic::vector_reduce_fadd:
333+
case Intrinsic::vector_reduce_fmul:
334+
case Intrinsic::vector_reduce_fmax:
335+
case Intrinsic::vector_reduce_fmaximum:
336+
case Intrinsic::vector_reduce_fmin:
337+
case Intrinsic::vector_reduce_fminimum:
338+
case Intrinsic::vector_reduce_add:
339+
case Intrinsic::vector_reduce_and:
340+
case Intrinsic::vector_reduce_mul:
341+
case Intrinsic::vector_reduce_or:
342+
case Intrinsic::vector_reduce_xor:
343+
return ShapeInfo(1, 1);
344+
default:
345+
break;
346+
}
347+
}
348+
328349
if (isUniformShape(I) || isa<SelectInst>(I)) {
329350
auto Ops = I->operands();
330351
auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
@@ -468,7 +489,7 @@ class LowerMatrixIntrinsics {
468489
return make_range(Vectors.begin(), Vectors.end());
469490
}
470491

471-
iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
492+
iterator_range<SmallVector<Value *, 8>::const_iterator> vectors() const {
472493
return make_range(Vectors.begin(), Vectors.end());
473494
}
474495

@@ -701,7 +722,31 @@ class LowerMatrixIntrinsics {
701722
case Intrinsic::matrix_transpose:
702723
case Intrinsic::matrix_column_major_load:
703724
case Intrinsic::matrix_column_major_store:
725+
case Intrinsic::vector_reduce_fmax:
726+
case Intrinsic::vector_reduce_fmaximum:
727+
case Intrinsic::vector_reduce_fmin:
728+
case Intrinsic::vector_reduce_fminimum:
729+
case Intrinsic::vector_reduce_add:
730+
case Intrinsic::vector_reduce_and:
731+
case Intrinsic::vector_reduce_mul:
732+
case Intrinsic::vector_reduce_or:
733+
case Intrinsic::vector_reduce_xor:
704734
return true;
735+
case Intrinsic::vector_reduce_fadd:
736+
case Intrinsic::vector_reduce_fmul: {
737+
FastMathFlags FMF = getFastMathFlags(Inst);
738+
if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
739+
return false;
740+
741+
if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>(
742+
m_Unless(m_AnyZeroFP()), m_Value())))
743+
return false;
744+
745+
if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>(
746+
m_Unless(m_FPOne()), m_Value())))
747+
return false;
748+
return true;
749+
}
705750
default:
706751
return isUniformShape(II);
707752
}
@@ -1268,6 +1313,113 @@ class LowerMatrixIntrinsics {
12681313
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
12691314
Result.getNumVectors());
12701315
}
1316+
case Intrinsic::vector_reduce_fadd: {
1317+
Builder.setFastMathFlags(getFastMathFlags(Inst));
1318+
auto *I = Inst2ColumnMatrix.find(Inst->getOperand(1));
1319+
assert(I != Inst2ColumnMatrix.end());
1320+
const MatrixTy &M = I->second;
1321+
1322+
Value *Start = Inst->getOperand(0);
1323+
Value *ResultV = Builder.CreateVectorSplat(
1324+
ElementCount::getFixed(M.getStride()), Start);
1325+
for (auto &Vector : M.vectors())
1326+
ResultV = Builder.CreateFAdd(ResultV, Vector);
1327+
1328+
Value *Result = Builder.CreateFAddReduce(Start, ResultV);
1329+
Inst->replaceAllUsesWith(Result);
1330+
Result->takeName(Inst);
1331+
return MatrixTy{Result};
1332+
} break;
1333+
case Intrinsic::vector_reduce_fmul: {
1334+
Builder.setFastMathFlags(getFastMathFlags(Inst));
1335+
auto *I = Inst2ColumnMatrix.find(Inst->getOperand(1));
1336+
assert(I != Inst2ColumnMatrix.end());
1337+
const MatrixTy &M = I->second;
1338+
1339+
Value *Start = Inst->getOperand(0);
1340+
Value *ResultV = Builder.CreateVectorSplat(
1341+
ElementCount::getFixed(M.getStride()), Start);
1342+
for (auto &Vector : M.vectors())
1343+
ResultV = Builder.CreateFMul(ResultV, Vector);
1344+
1345+
Value *Result = Builder.CreateFMulReduce(Start, ResultV);
1346+
Inst->replaceAllUsesWith(Result);
1347+
Result->takeName(Inst);
1348+
return MatrixTy{Result};
1349+
} break;
1350+
case Intrinsic::vector_reduce_fmax:
1351+
case Intrinsic::vector_reduce_fmaximum:
1352+
case Intrinsic::vector_reduce_fmin:
1353+
case Intrinsic::vector_reduce_fminimum:
1354+
case Intrinsic::vector_reduce_add:
1355+
case Intrinsic::vector_reduce_and:
1356+
case Intrinsic::vector_reduce_mul:
1357+
case Intrinsic::vector_reduce_or:
1358+
case Intrinsic::vector_reduce_xor: {
1359+
Builder.setFastMathFlags(getFastMathFlags(Inst));
1360+
auto *I = Inst2ColumnMatrix.find(Inst->getOperand(0));
1361+
assert(I != Inst2ColumnMatrix.end());
1362+
const MatrixTy &M = I->second;
1363+
1364+
auto CreateVReduce = [&](Value *LHS, Value *RHS) {
1365+
switch (Inst->getIntrinsicID()) {
1366+
case Intrinsic::vector_reduce_add:
1367+
return Builder.CreateAdd(LHS, RHS);
1368+
case Intrinsic::vector_reduce_and:
1369+
return Builder.CreateAnd(LHS, RHS);
1370+
case Intrinsic::vector_reduce_fmax:
1371+
return Builder.CreateMaximum(LHS, RHS);
1372+
case Intrinsic::vector_reduce_fmaximum:
1373+
return Builder.CreateMaximumNum(LHS, RHS);
1374+
case Intrinsic::vector_reduce_fmin:
1375+
return Builder.CreateMinimum(LHS, RHS);
1376+
case Intrinsic::vector_reduce_fminimum:
1377+
return Builder.CreateMinimumNum(LHS, RHS);
1378+
case Intrinsic::vector_reduce_mul:
1379+
return Builder.CreateMul(LHS, RHS);
1380+
case Intrinsic::vector_reduce_or:
1381+
return Builder.CreateOr(LHS, RHS);
1382+
case Intrinsic::vector_reduce_xor:
1383+
return Builder.CreateXor(LHS, RHS);
1384+
default:
1385+
llvm_unreachable("unexpected intrinsic");
1386+
}
1387+
};
1388+
1389+
Value *ResultV = M.getVector(0);
1390+
for (auto &Vector : drop_begin(M.vectors()))
1391+
ResultV = CreateVReduce(ResultV, Vector);
1392+
1393+
auto CreateHReduce = [&](Value *V) {
1394+
switch (Inst->getIntrinsicID()) {
1395+
case Intrinsic::vector_reduce_add:
1396+
return Builder.CreateAddReduce(V);
1397+
case Intrinsic::vector_reduce_and:
1398+
return Builder.CreateAndReduce(V);
1399+
case Intrinsic::vector_reduce_fmax:
1400+
return Builder.CreateFPMaxReduce(V);
1401+
case Intrinsic::vector_reduce_fmaximum:
1402+
return Builder.CreateFPMaximumReduce(V);
1403+
case Intrinsic::vector_reduce_fmin:
1404+
return Builder.CreateFPMinReduce(V);
1405+
case Intrinsic::vector_reduce_fminimum:
1406+
return Builder.CreateFPMinimumReduce(V);
1407+
case Intrinsic::vector_reduce_mul:
1408+
return Builder.CreateMulReduce(V);
1409+
case Intrinsic::vector_reduce_or:
1410+
return Builder.CreateOrReduce(V);
1411+
case Intrinsic::vector_reduce_xor:
1412+
return Builder.CreateXorReduce(V);
1413+
default:
1414+
llvm_unreachable("unexpected intrinsic");
1415+
}
1416+
};
1417+
1418+
Value *Result = CreateHReduce(ResultV);
1419+
Inst->replaceAllUsesWith(Result);
1420+
Result->takeName(Inst);
1421+
return MatrixTy{Result};
1422+
} break;
12711423
default:
12721424
break;
12731425
}

0 commit comments

Comments
 (0)