|
32 | 32 | #include "llvm/Analysis/ValueTracking.h" |
33 | 33 | #include "llvm/Analysis/VectorUtils.h" |
34 | 34 | #include "llvm/IR/CFG.h" |
| 35 | +#include "llvm/IR/Constants.h" |
35 | 36 | #include "llvm/IR/DataLayout.h" |
36 | 37 | #include "llvm/IR/DebugInfoMetadata.h" |
37 | 38 | #include "llvm/IR/DerivedTypes.h" |
|
41 | 42 | #include "llvm/IR/Instructions.h" |
42 | 43 | #include "llvm/IR/IntrinsicInst.h" |
43 | 44 | #include "llvm/IR/MatrixBuilder.h" |
| 45 | +#include "llvm/IR/Operator.h" |
44 | 46 | #include "llvm/IR/PatternMatch.h" |
45 | 47 | #include "llvm/Support/Alignment.h" |
46 | 48 | #include "llvm/Support/CommandLine.h" |
@@ -325,6 +327,25 @@ computeShapeInfoForInst(Instruction *I, |
325 | 327 | return OpShape->second; |
326 | 328 | } |
327 | 329 |
|
| 330 | + if (auto *II = dyn_cast<IntrinsicInst>(I)) { |
| 331 | + switch (II->getIntrinsicID()) { |
| 332 | + case Intrinsic::vector_reduce_fadd: |
| 333 | + case Intrinsic::vector_reduce_fmul: |
| 334 | + case Intrinsic::vector_reduce_fmax: |
| 335 | + case Intrinsic::vector_reduce_fmaximum: |
| 336 | + case Intrinsic::vector_reduce_fmin: |
| 337 | + case Intrinsic::vector_reduce_fminimum: |
| 338 | + case Intrinsic::vector_reduce_add: |
| 339 | + case Intrinsic::vector_reduce_and: |
| 340 | + case Intrinsic::vector_reduce_mul: |
| 341 | + case Intrinsic::vector_reduce_or: |
| 342 | + case Intrinsic::vector_reduce_xor: |
| 343 | + return ShapeInfo(1, 1); |
| 344 | + default: |
| 345 | + break; |
| 346 | + } |
| 347 | + } |
| 348 | + |
328 | 349 | if (isUniformShape(I) || isa<SelectInst>(I)) { |
329 | 350 | auto Ops = I->operands(); |
330 | 351 | auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops; |
@@ -468,7 +489,7 @@ class LowerMatrixIntrinsics { |
468 | 489 | return make_range(Vectors.begin(), Vectors.end()); |
469 | 490 | } |
470 | 491 |
|
471 | | - iterator_range<SmallVector<Value *, 8>::iterator> vectors() { |
| 492 | + iterator_range<SmallVector<Value *, 8>::const_iterator> vectors() const { |
472 | 493 | return make_range(Vectors.begin(), Vectors.end()); |
473 | 494 | } |
474 | 495 |
|
@@ -701,7 +722,31 @@ class LowerMatrixIntrinsics { |
701 | 722 | case Intrinsic::matrix_transpose: |
702 | 723 | case Intrinsic::matrix_column_major_load: |
703 | 724 | case Intrinsic::matrix_column_major_store: |
| 725 | + case Intrinsic::vector_reduce_fmax: |
| 726 | + case Intrinsic::vector_reduce_fmaximum: |
| 727 | + case Intrinsic::vector_reduce_fmin: |
| 728 | + case Intrinsic::vector_reduce_fminimum: |
| 729 | + case Intrinsic::vector_reduce_add: |
| 730 | + case Intrinsic::vector_reduce_and: |
| 731 | + case Intrinsic::vector_reduce_mul: |
| 732 | + case Intrinsic::vector_reduce_or: |
| 733 | + case Intrinsic::vector_reduce_xor: |
704 | 734 | return true; |
| 735 | + case Intrinsic::vector_reduce_fadd: |
| 736 | + case Intrinsic::vector_reduce_fmul: { |
| 737 | + FastMathFlags FMF = getFastMathFlags(Inst); |
| 738 | + if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc()) |
| 739 | + return false; |
| 740 | + |
| 741 | + if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>( |
| 742 | + m_Unless(m_AnyZeroFP()), m_Value()))) |
| 743 | + return false; |
| 744 | + |
| 745 | + if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>( |
| 746 | + m_Unless(m_FPOne()), m_Value()))) |
| 747 | + return false; |
| 748 | + return true; |
| 749 | + } |
705 | 750 | default: |
706 | 751 | return isUniformShape(II); |
707 | 752 | } |
@@ -1268,6 +1313,113 @@ class LowerMatrixIntrinsics { |
1268 | 1313 | return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * |
1269 | 1314 | Result.getNumVectors()); |
1270 | 1315 | } |
| 1316 | + case Intrinsic::vector_reduce_fadd: { |
| 1317 | + Builder.setFastMathFlags(getFastMathFlags(Inst)); |
| 1318 | + auto *I = Inst2ColumnMatrix.find(Inst->getOperand(1)); |
| 1319 | + assert(I != Inst2ColumnMatrix.end()); |
| 1320 | + const MatrixTy &M = I->second; |
| 1321 | + |
| 1322 | + Value *Start = Inst->getOperand(0); |
| 1323 | + Value *ResultV = Builder.CreateVectorSplat( |
| 1324 | + ElementCount::getFixed(M.getStride()), Start); |
| 1325 | + for (auto &Vector : M.vectors()) |
| 1326 | + ResultV = Builder.CreateFAdd(ResultV, Vector); |
| 1327 | + |
| 1328 | + Value *Result = Builder.CreateFAddReduce(Start, ResultV); |
| 1329 | + Inst->replaceAllUsesWith(Result); |
| 1330 | + Result->takeName(Inst); |
| 1331 | + return MatrixTy{Result}; |
| 1332 | + } break; |
| 1333 | + case Intrinsic::vector_reduce_fmul: { |
| 1334 | + Builder.setFastMathFlags(getFastMathFlags(Inst)); |
| 1335 | + auto *I = Inst2ColumnMatrix.find(Inst->getOperand(1)); |
| 1336 | + assert(I != Inst2ColumnMatrix.end()); |
| 1337 | + const MatrixTy &M = I->second; |
| 1338 | + |
| 1339 | + Value *Start = Inst->getOperand(0); |
| 1340 | + Value *ResultV = Builder.CreateVectorSplat( |
| 1341 | + ElementCount::getFixed(M.getStride()), Start); |
| 1342 | + for (auto &Vector : M.vectors()) |
| 1343 | + ResultV = Builder.CreateFMul(ResultV, Vector); |
| 1344 | + |
| 1345 | + Value *Result = Builder.CreateFMulReduce(Start, ResultV); |
| 1346 | + Inst->replaceAllUsesWith(Result); |
| 1347 | + Result->takeName(Inst); |
| 1348 | + return MatrixTy{Result}; |
| 1349 | + } break; |
| 1350 | + case Intrinsic::vector_reduce_fmax: |
| 1351 | + case Intrinsic::vector_reduce_fmaximum: |
| 1352 | + case Intrinsic::vector_reduce_fmin: |
| 1353 | + case Intrinsic::vector_reduce_fminimum: |
| 1354 | + case Intrinsic::vector_reduce_add: |
| 1355 | + case Intrinsic::vector_reduce_and: |
| 1356 | + case Intrinsic::vector_reduce_mul: |
| 1357 | + case Intrinsic::vector_reduce_or: |
| 1358 | + case Intrinsic::vector_reduce_xor: { |
| 1359 | + Builder.setFastMathFlags(getFastMathFlags(Inst)); |
| 1360 | + auto *I = Inst2ColumnMatrix.find(Inst->getOperand(0)); |
| 1361 | + assert(I != Inst2ColumnMatrix.end()); |
| 1362 | + const MatrixTy &M = I->second; |
| 1363 | + |
| 1364 | + auto CreateVReduce = [&](Value *LHS, Value *RHS) { // Combine two vectors of M lane-wise with the binary op matching Inst's reduction. |
| 1365 | + switch (Inst->getIntrinsicID()) { |
| 1366 | + case Intrinsic::vector_reduce_add: |
| 1367 | + return Builder.CreateAdd(LHS, RHS); |
| 1368 | + case Intrinsic::vector_reduce_and: |
| 1369 | + return Builder.CreateAnd(LHS, RHS); |
| 1370 | + case Intrinsic::vector_reduce_fmax: // reduce.fmax has llvm.maxnum semantics (a NaN operand is dropped), not llvm.maximum. |
| 1371 | + return Builder.CreateMaxNum(LHS, RHS); |
| 1372 | + case Intrinsic::vector_reduce_fmaximum: // reduce.fmaximum has llvm.maximum semantics (NaN propagates). |
| 1373 | + return Builder.CreateMaximum(LHS, RHS); |
| 1374 | + case Intrinsic::vector_reduce_fmin: // reduce.fmin has llvm.minnum semantics (a NaN operand is dropped), not llvm.minimum. |
| 1375 | + return Builder.CreateMinNum(LHS, RHS); |
| 1376 | + case Intrinsic::vector_reduce_fminimum: // reduce.fminimum has llvm.minimum semantics (NaN propagates). |
| 1377 | + return Builder.CreateMinimum(LHS, RHS); |
| 1378 | + case Intrinsic::vector_reduce_mul: |
| 1379 | + return Builder.CreateMul(LHS, RHS); |
| 1380 | + case Intrinsic::vector_reduce_or: |
| 1381 | + return Builder.CreateOr(LHS, RHS); |
| 1382 | + case Intrinsic::vector_reduce_xor: |
| 1383 | + return Builder.CreateXor(LHS, RHS); |
| 1384 | + default: // Only the intrinsic IDs named by the enclosing case labels reach this lambda. |
| 1385 | + llvm_unreachable("unexpected intrinsic"); |
| 1386 | + } |
| 1387 | + }; |
| 1388 | + |
| 1389 | + Value *ResultV = M.getVector(0); |
| 1390 | + for (auto &Vector : drop_begin(M.vectors())) |
| 1391 | + ResultV = CreateVReduce(ResultV, Vector); |
| 1392 | + |
| 1393 | + auto CreateHReduce = [&](Value *V) { // Emit the reduction of vector V selected by Inst's intrinsic ID and return the resulting value. |
| 1394 | + switch (Inst->getIntrinsicID()) { |
| 1395 | + case Intrinsic::vector_reduce_add: |
| 1396 | + return Builder.CreateAddReduce(V); |
| 1397 | + case Intrinsic::vector_reduce_and: |
| 1398 | + return Builder.CreateAndReduce(V); |
| 1399 | + case Intrinsic::vector_reduce_fmax: |
| 1400 | + return Builder.CreateFPMaxReduce(V); |
| 1401 | + case Intrinsic::vector_reduce_fmaximum: |
| 1402 | + return Builder.CreateFPMaximumReduce(V); |
| 1403 | + case Intrinsic::vector_reduce_fmin: |
| 1404 | + return Builder.CreateFPMinReduce(V); |
| 1405 | + case Intrinsic::vector_reduce_fminimum: |
| 1406 | + return Builder.CreateFPMinimumReduce(V); |
| 1407 | + case Intrinsic::vector_reduce_mul: |
| 1408 | + return Builder.CreateMulReduce(V); |
| 1409 | + case Intrinsic::vector_reduce_or: |
| 1410 | + return Builder.CreateOrReduce(V); |
| 1411 | + case Intrinsic::vector_reduce_xor: |
| 1412 | + return Builder.CreateXorReduce(V); |
| 1413 | + default: // Only the intrinsic IDs named by the enclosing case labels reach this lambda. |
| 1414 | + llvm_unreachable("unexpected intrinsic"); |
| 1415 | + } |
| 1416 | + }; |
| 1417 | + |
| 1418 | + Value *Result = CreateHReduce(ResultV); |
| 1419 | + Inst->replaceAllUsesWith(Result); |
| 1420 | + Result->takeName(Inst); |
| 1421 | + return MatrixTy{Result}; |
| 1422 | + } break; |
1271 | 1423 | default: |
1272 | 1424 | break; |
1273 | 1425 | } |
|
0 commit comments