diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 5a518244a80ca..1d05cd47fa0cf 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" @@ -40,6 +41,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -52,6 +54,10 @@ using namespace PatternMatch; #define DEBUG_TYPE "lower-matrix-intrinsics" +STATISTIC(FlattenedMatrices, "Number of matrix flattenings"); +STATISTIC(ReshapedMatrices, "Number of matrix reshapes"); +STATISTIC(SplitMatrices, "Number of matrix splits"); + static cl::opt FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, cl::desc("Enable/disable fusing matrix instructions.")); @@ -221,7 +227,16 @@ struct ShapeInfo { /// Returns the transposed shape. ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } + + friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI); + + LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; } }; + +raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) { + return OS << SI.NumRows << 'x' << SI.NumColumns; +} + } // namespace static bool isUniformShape(Value *V) { @@ -466,6 +481,8 @@ class LowerMatrixIntrinsics { return getNumColumns(); } + ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; } + /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the /// matrix is column-major, the result vector is extracted from a column /// vector, otherwise from a row vector. @@ -578,6 +595,28 @@ class LowerMatrixIntrinsics { SplitVecs.push_back(V); } + LLVM_DEBUG(if (Instruction *Inst = dyn_cast(MatrixVal)) { + if (Found != Inst2ColumnMatrix.end()) { + // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles + // that embedInVector created. + dbgs() << "matrix reshape from " << Found->second.shape() << " to " + << SI << " using at least " << SplitVecs.size() + << " shuffles on behalf of:\n" + << *Inst << '\n'; + ReshapedMatrices++; + } else if (!ShapeMap.contains(MatrixVal)) { + dbgs() << "splitting a " << SI << " matrix with " << SplitVecs.size() + << " shuffles beacuse we do not have a shape-aware lowering for " + "its def:\n" + << *Inst << '\n'; + SplitMatrices++; + } else { + // The ShapeMap has it, so it's a case where we're being lowered + // before the def, and we expect that InstCombine will clean things up + // afterward. + } + }); + return {SplitVecs}; } @@ -1386,11 +1425,21 @@ class LowerMatrixIntrinsics { ToRemove.push_back(Inst); Value *Flattened = nullptr; for (Use &U : llvm::make_early_inc_range(Inst->uses())) { - if (!ShapeMap.contains(U.getUser())) { - if (!Flattened) - Flattened = Matrix.embedInVector(Builder); - U.set(Flattened); + if (ShapeMap.contains(U.getUser())) + continue; + + if (!Flattened) { + Flattened = Matrix.embedInVector(Builder); + LLVM_DEBUG( + if (Instruction *User = dyn_cast(U.getUser())) dbgs() + << "flattening a " << Matrix.shape() << " matrix:\n" + << *Inst + << "\nbecause we do not have a shape-aware lowering for its " + "user:\n" + << *User << '\n';); + FlattenedMatrices++; } + U.set(Flattened); } } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll new file mode 100644 index 0000000000000..8a928c36baffd --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll @@ -0,0 +1,61 @@ +; RUN: opt -passes=lower-matrix-intrinsics -debug-only=lower-matrix-intrinsics -disable-output < %s 2>&1 | FileCheck %s --check-prefix=CHECK +; REQUIRES: asserts + +define void @diag_3x3(ptr %in, ptr %out) { + %inv = call <9 x float> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 false, i32 3, i32 3) + %diag = shufflevector <9 x float> %inv, <9 x float> poison, <3 x i32> + store <3 x float> %diag, ptr %out + ret void +} +; CHECK-LABEL: flattening a 3x3 matrix: +; CHECK-NEXT: %{{.*}} = call <9 x float> @llvm.matrix.column.major.load.v9f32.i64(ptr %{{.*}}, i64 3, i1 false, i32 3, i32 3) +; CHECK-NEXT: because we do not have a shape-aware lowering for its user: +; CHECK-NEXT: %{{.*}} = shufflevector <9 x float> %{{.*}}, <9 x float> poison, <3 x i32> + +define void @reshape(ptr %in, ptr %out) { +entry: + %0 = load <4 x double>, ptr %in, align 8 + %1 = tail call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %0, i32 4, i32 1) + %2 = tail call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %1, i32 1, i32 4) + %3 = tail call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %2, i32 2, i32 2) + %4 = tail call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %3, i32 2, i32 2) + %5 = tail call <4 x double> @llvm.matrix.transpose.v4f64(<4 x double> %4, i32 2, i32 2) + store <4 x double> %5, ptr %out, align 8 + ret void +} +; CHECK-LABEL: matrix reshape from 4x1 to 2x2 using at least 2 shuffles on behalf of: +; CHECK-NEXT: %{{.*}} = load <4 x double>, ptr %{{.*}}, align 8 + +define void @multiply_ntt(ptr %A, ptr %B, ptr %C, ptr %R) { +entry: + %a = load <6 x double>, ptr %A, align 16 + %b = load <6 x double>, ptr %B, align 16 + %c = load <8 x double>, ptr %C, align 16 + %b_t = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %b, i32 2, i32 3) + %c_t = call <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double> %c, i32 4, i32 2) + %m1 = call <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double> %b_t, <8 x double> %c_t, i32 3, i32 2, i32 4) + %m2 = call <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double> %m1, i32 2, i32 3, i32 4) + store <8 x double> %m2, ptr %R, align 16 + ret void +} +; CHECK-LABEL: flattening a 2x3 matrix: +; CHECK-NEXT: %{{.*}} = load <6 x double>, ptr %{{.*}}, align 16 +; CHECK-NEXT: because we do not have a shape-aware lowering for its user: +; CHECK-NEXT: %{{.*}} = shufflevector <6 x double> %{{.*}}, <6 x double> poison, <2 x i32> + +; CHECK-LABEL: flattening a 4x3 matrix: +; CHECK-NEXT: %{{.*}} = call <12 x double> @llvm.matrix.multiply.v12f64.v8f64.v6f64(<8 x double> %{{.*}}, <6 x double> %{{.*}}, i32 4, i32 2, i32 3) +; CHECK-NEXT: because we do not have a shape-aware lowering for its user: +; CHECK-NEXT: %{{.*}} = shufflevector <12 x double> %{{.*}}, <12 x double> poison, <4 x i32> + + +define void @redundant_transpose_of_shuffle(<4 x float> %m, ptr %dst) { +entry: + %shuffle = shufflevector <4 x float> %m, <4 x float> zeroinitializer, <4 x i32> zeroinitializer + %t = tail call <4 x float> @llvm.matrix.transpose.v3f32(<4 x float> %shuffle, i32 1, i32 4) + store <4 x float> %t, ptr %dst, align 4 + ret void +} + +; CHECK-LABEL: splitting a 4x1 matrix with 1 shuffles beacuse we do not have a shape-aware lowering for its def: +; CHECK-NEXT: %{{.*}} = shufflevector <4 x float> %{{.*}}, <4 x float> zeroinitializer, <4 x i32> zeroinitializer \ No newline at end of file