diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 756a72e6d97bc..93d56a9a7bd4a 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -32,6 +33,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -1337,6 +1339,57 @@ class LowerMatrixIntrinsics { return Builder.CreateAdd(Sum, Mul); } + bool tryLowerShuffleVector(ShuffleVectorInst *Inst) { + Value *Op0 = Inst->getOperand(0), *Op1 = Inst->getOperand(1); + SmallVector Mask; + Inst->getShuffleMask(Mask); + + auto *Ty = cast(Op0->getType()); + + if (Mask[0] == PoisonMaskElem) + return false; + + // Check if the Mask implies a contiguous extraction, i.e. one column of a + // column-major matrix (or row of a row-major one). + for (int I = 1, E = Mask.size(); I != E; ++I) { + if (Mask[I] == PoisonMaskElem) + return false; + if (Mask[I-1] + 1 != Mask[I]) + return false; + } + + auto VectorForIndex = [&](int Idx) { + return Idx < int(Ty->getNumElements()) ? Op0 : Op1; + }; + + // Check if the Mask extracts from a single source operand. + Value *Op = VectorForIndex(Mask.front()); + if (Op != VectorForIndex(Mask.back())) + return false; + + auto *I = Inst2ColumnMatrix.find(Op); + if (I == Inst2ColumnMatrix.end()) + return false; + + const MatrixTy &M = I->second; + + // Check if the Mask extracts one entire vector from the matrix. + if (Mask.size() != M.getStride()) + return false; + + // Check if the result would span two of the vectors in the matrix. + // TODO: we could handle this case by creating a new shuffle, if we see that + // happening in the wild. + if (0 != Mask[0] % M.getStride()) + return false; + + Value *Result = M.getVector(Mask[0] / M.getStride()); + Inst->replaceAllUsesWith(Result); + Result->takeName(Inst); + Inst->eraseFromParent(); + return true; + } + /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For /// users with shape information, there's nothing to do: they will use the /// cached value when they are lowered. For other users, \p Matrix is @@ -1351,11 +1404,16 @@ class LowerMatrixIntrinsics { ToRemove.push_back(Inst); Value *Flattened = nullptr; for (Use &U : llvm::make_early_inc_range(Inst->uses())) { - if (!ShapeMap.contains(U.getUser())) { - if (!Flattened) - Flattened = Matrix.embedInVector(Builder); - U.set(Flattened); - } + if (ShapeMap.contains(U.getUser())) + continue; + + if (auto *Intr = dyn_cast(U.getUser())) + if (tryLowerShuffleVector(Intr)) + continue; + + if (!Flattened) + Flattened = Matrix.embedInVector(Builder); + U.set(Flattened); } } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll new file mode 100644 index 0000000000000..21f49d2561d4e --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/shuffle.ll @@ -0,0 +1,34 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define <3 x double> @extract_column(ptr %in, ptr %out) { +; CHECK-LABEL: @extract_column( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: ret <3 x double> [[COL_LOAD3]] +; + %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3) + %col = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> + ret <3 x double> %col +} + +define <3 x double> @extract_row(ptr %in, ptr %out) { +; CHECK-LABEL: @extract_row( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <3 x double>, ptr [[IN:%.*]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 3 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 6 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load volatile <3 x double>, ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD1]], <6 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <3 x double> [[COL_LOAD3]], <3 x double> poison, <6 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <6 x double> [[TMP1]], <6 x double> [[TMP2]], <9 x i32> +; CHECK-NEXT: [[ROW:%.*]] = shufflevector <9 x double> [[TMP3]], <9 x double> poison, <3 x i32> +; CHECK-NEXT: ret <3 x double> [[ROW]] +; + %inv = call <9 x double> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 3) + %row = shufflevector <9 x double> %inv, <9 x double> poison, <3 x i32> + ret <3 x double> %row +}