1919
2020#include " llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
2121#include " llvm/ADT/PostOrderIterator.h"
22+ #include " llvm/ADT/STLExtras.h"
2223#include " llvm/ADT/ScopeExit.h"
2324#include " llvm/ADT/SmallSet.h"
2425#include " llvm/ADT/SmallVector.h"
3233#include " llvm/IR/CFG.h"
3334#include " llvm/IR/DataLayout.h"
3435#include " llvm/IR/DebugInfoMetadata.h"
36+ #include " llvm/IR/DerivedTypes.h"
3537#include " llvm/IR/Function.h"
3638#include " llvm/IR/IRBuilder.h"
3739#include " llvm/IR/Instructions.h"
@@ -1337,6 +1339,57 @@ class LowerMatrixIntrinsics {
13371339 return Builder.CreateAdd (Sum, Mul);
13381340 }
13391341
1342+ bool tryLowerShuffleVector (ShuffleVectorInst *Inst) {
1343+ Value *Op0 = Inst->getOperand (0 ), *Op1 = Inst->getOperand (1 );
1344+ SmallVector<int > Mask;
1345+ Inst->getShuffleMask (Mask);
1346+
1347+ auto *Ty = cast<FixedVectorType>(Op0->getType ());
1348+
1349+ if (Mask[0 ] == PoisonMaskElem)
1350+ return false ;
1351+
1352+ // Check if the Mask implies a contiguous extraction, i.e. one column of a
1353+ // column-major matrix (or row of a row-major one).
1354+ for (int I = 1 , E = Mask.size (); I != E; ++I) {
1355+ if (Mask[I] == PoisonMaskElem)
1356+ return false ;
1357+ if (Mask[I-1 ] + 1 != Mask[I])
1358+ return false ;
1359+ }
1360+
1361+ auto VectorForIndex = [&](int Idx) {
1362+ return Idx < int (Ty->getNumElements ()) ? Op0 : Op1;
1363+ };
1364+
1365+ // Check if the Mask extracts from a single source operand.
1366+ Value *Op = VectorForIndex (Mask.front ());
1367+ if (Op != VectorForIndex (Mask.back ()))
1368+ return false ;
1369+
1370+ auto *I = Inst2ColumnMatrix.find (Op);
1371+ if (I == Inst2ColumnMatrix.end ())
1372+ return false ;
1373+
1374+ const MatrixTy &M = I->second ;
1375+
1376+ // Check if the Mask extracts one entire vector from the matrix.
1377+ if (Mask.size () != M.getStride ())
1378+ return false ;
1379+
1380+ // Check if the result would span two of the vectors in the matrix.
1381+ // TODO: we could handle this case by creating a new shuffle, if we see that
1382+ // happening in the wild.
1383+ if (0 != Mask[0 ] % M.getStride ())
1384+ return false ;
1385+
1386+ Value *Result = M.getVector (Mask[0 ] / M.getStride ());
1387+ Inst->replaceAllUsesWith (Result);
1388+ Result->takeName (Inst);
1389+ Inst->eraseFromParent ();
1390+ return true ;
1391+ }
1392+
13401393 // / Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
13411394 // / users with shape information, there's nothing to do: they will use the
13421395 // / cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1404,16 @@ class LowerMatrixIntrinsics {
13511404 ToRemove.push_back (Inst);
13521405 Value *Flattened = nullptr ;
13531406 for (Use &U : llvm::make_early_inc_range (Inst->uses ())) {
1354- if (!ShapeMap.contains (U.getUser ())) {
1355- if (!Flattened)
1356- Flattened = Matrix.embedInVector (Builder);
1357- U.set (Flattened);
1358- }
1407+ if (ShapeMap.contains (U.getUser ()))
1408+ continue ;
1409+
1410+ if (auto *Intr = dyn_cast<ShuffleVectorInst>(U.getUser ()))
1411+ if (tryLowerShuffleVector (Intr))
1412+ continue ;
1413+
1414+ if (!Flattened)
1415+ Flattened = Matrix.embedInVector (Builder);
1416+ U.set (Flattened);
13591417 }
13601418 }
13611419
0 commit comments