[Matrix] Use data layout index type for lowering matrix intrinsics #162646
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms

Author: Nathan Corbyn (cofibrant)

Changes

I've also included a commit that slightly refactors how shape information is propagated.

CC @fhahn

Patch is 70.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162646.diff

7 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 56e0569831e83..408372efdb93b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -44,7 +44,7 @@
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"
@@ -241,11 +241,16 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
} // namespace
-static bool isUniformShape(Value *V) {
+/// Returns true if \p V is an instruction whose result is of the same shape
+/// as its operands (or if \p V is a non-instruction value).
+static bool isShapePreserving(Value *V) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I)
return true;
+ if (isa<SelectInst>(I))
+ return true;
+
if (I->isBinaryOp())
return true;
@@ -296,6 +301,13 @@ static bool isUniformShape(Value *V) {
}
}
+static iterator_range<Use *> getShapedOperands(Instruction *I) {
+ auto Ops = I->operands();
+ // Ignore shape information for the predicate operand of a `select`
+ // instruction
+ return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+}
+
/// Return the ShapeInfo for the result of \p I, if it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
@@ -325,9 +337,8 @@ computeShapeInfoForInst(Instruction *I,
return OpShape->second;
}
- if (isUniformShape(I) || isa<SelectInst>(I)) {
- auto Ops = I->operands();
- auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+ if (isShapePreserving(I)) {
+ auto ShapedOps = getShapedOperands(I);
// Find the first operand that has a known shape and use that.
for (auto &Op : ShapedOps) {
auto OpShape = ShapeMap.find(Op.get());
@@ -633,18 +644,16 @@ class LowerMatrixIntrinsics {
if (Found != Inst2ColumnMatrix.end()) {
// FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
// that embedInVector created.
- LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
- << " to " << SI << " using at least "
- << SplitVecs.size() << " shuffles on behalf of:\n"
- << *Inst << '\n');
+ LDBG() << "matrix reshape from " << Found->second.shape() << " to "
+ << SI << " using at least " << SplitVecs.size()
+ << " shuffles on behalf of:\n"
+ << *Inst << '\n';
ReshapedMatrices++;
} else if (!ShapeMap.contains(MatrixVal)) {
- LLVM_DEBUG(
- dbgs()
- << "splitting a " << SI << " matrix with " << SplitVecs.size()
- << " shuffles beacuse we do not have a shape-aware lowering for "
- "its def:\n"
- << *Inst << '\n');
+ LDBG() << "splitting a " << SI << " matrix with " << SplitVecs.size()
+ << " shuffles beacuse we do not have a shape-aware lowering for "
+ "its def:\n"
+ << *Inst << '\n';
(void)Inst;
SplitMatrices++;
} else {
@@ -675,15 +684,14 @@ class LowerMatrixIntrinsics {
"Matrix shape verification failed, compilation aborted!");
}
- LLVM_DEBUG(dbgs() << " not overriding existing shape: "
- << SIter->second.NumRows << " "
- << SIter->second.NumColumns << " for " << *V << "\n");
+ LDBG() << " not overriding existing shape: " << SIter->second.NumRows
+ << " " << SIter->second.NumColumns << " for " << *V << "\n";
return false;
}
ShapeMap.insert({V, Shape});
- LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
- << " for " << *V << "\n");
+ LDBG() << " " << Shape.NumRows << " x " << Shape.NumColumns << " for "
+ << *V << "\n";
return true;
}
@@ -703,10 +711,9 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
return true;
default:
- return isUniformShape(II);
+ break;
}
- return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
- isa<SelectInst>(V);
+ return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}
/// Propagate the shape information of instructions to their users.
@@ -719,7 +726,7 @@ class LowerMatrixIntrinsics {
// Pop an element for which we guaranteed to have at least one of the
// operand shapes. Add the shape for this and then add users to the work
// list.
- LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
+ LDBG() << "Forward-propagate shapes:\n";
while (!WorkList.empty()) {
Instruction *Inst = WorkList.pop_back_val();
@@ -754,7 +761,7 @@ class LowerMatrixIntrinsics {
// Pop an element with known shape. Traverse the operands, if their shape
// derives from the result shape and is unknown, add it and add them to the
// worklist.
- LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
+ LDBG() << "Backward-propagate shapes:\n";
while (!WorkList.empty()) {
Value *V = WorkList.pop_back_val();
@@ -778,7 +785,8 @@ class LowerMatrixIntrinsics {
} else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
+ // We're told MatrixA is M x N so propagate this information directly.
+ // Compare \c computeShapeInfoForInst where the dimensions are flipped.
if (setShapeInfo(MatrixA, {M, N}))
pushInstruction(MatrixA, WorkList);
} else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
@@ -793,10 +801,9 @@ class LowerMatrixIntrinsics {
} else if (isa<StoreInst>(V)) {
// Nothing to do. We forward-propagated to this so we would just
// backward propagate to an instruction with an already known shape.
- } else if (isUniformShape(V) || isa<SelectInst>(V)) {
- auto Ops = cast<Instruction>(V)->operands();
- auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
- // Propagate to all operands.
+ } else if (isShapePreserving(V)) {
+ auto ShapedOps = getShapedOperands(cast<Instruction>(V));
+ // Propagate to all shaped operands.
ShapeInfo Shape = ShapeMap[V];
for (Use &U : ShapedOps) {
if (setShapeInfo(U.get(), Shape))
@@ -1295,6 +1302,19 @@ class LowerMatrixIntrinsics {
return commonAlignment(InitialAlign, ElementSizeInBits / 8);
}
+ IntegerType *getIndexType(Value *Ptr) const {
+ return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
+ }
+
+ Value *getIndex(Value *Ptr, uint64_t V) const {
+ return ConstantInt::get(getIndexType(Ptr), V);
+ }
+
+ Value *truncateToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const {
+ assert(isa<IntegerType>(V->getType()));
+ return Builder.CreateZExtOrTrunc(V, getIndexType(Ptr), V->getName() + ".trunc");
+ }
+
/// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
/// vectors.
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
@@ -1304,6 +1324,7 @@ class LowerMatrixIntrinsics {
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
Value *EltPtr = Ptr;
MatrixTy Result;
+ Stride = truncateToIndexType(Ptr, Stride, Builder);
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
Value *GEP = computeVectorAddr(
EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
@@ -1325,14 +1346,14 @@ class LowerMatrixIntrinsics {
ShapeInfo ResultShape, Type *EltTy,
IRBuilder<> &Builder) {
Value *Offset = Builder.CreateAdd(
- Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+ Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
ResultShape.NumColumns);
return loadMatrix(TileTy, TileStart, Align,
- Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+ getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
ResultShape, Builder);
}
@@ -1363,14 +1384,15 @@ class LowerMatrixIntrinsics {
MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
Value *Offset = Builder.CreateAdd(
- Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+ Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
StoreVal.getNumColumns());
storeMatrix(TileTy, StoreVal, TileStart, MAlign,
- Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
+ getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
+ Builder);
}
/// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
@@ -1380,6 +1402,7 @@ class LowerMatrixIntrinsics {
IRBuilder<> &Builder) {
auto *VType = cast<FixedVectorType>(Ty);
Value *EltPtr = Ptr;
+ Stride = truncateToIndexType(Ptr, Stride, Builder);
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
EltPtr,
@@ -2011,18 +2034,17 @@ class LowerMatrixIntrinsics {
const unsigned TileM = std::min(M - K, unsigned(TileSize));
MatrixTy A =
loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
- LShape, Builder.getInt64(I), Builder.getInt64(K),
+ LShape, getIndex(APtr, I), getIndex(APtr, K),
{TileR, TileM}, EltType, Builder);
MatrixTy B =
loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
- RShape, Builder.getInt64(K), Builder.getInt64(J),
+ RShape, getIndex(BPtr, K), getIndex(BPtr, J),
{TileM, TileC}, EltType, Builder);
emitMatrixMultiply(Res, A, B, Builder, true, false,
getFastMathFlags(MatMul));
}
storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
- Builder.getInt64(I), Builder.getInt64(J), EltType,
- Builder);
+ getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);
}
}
@@ -2254,15 +2276,14 @@ class LowerMatrixIntrinsics {
/// Lower load instructions.
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
IRBuilder<> &Builder) {
- return LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
- Builder);
+ return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),
+ Inst->isVolatile(), SI, Builder);
}
MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
Value *Ptr, IRBuilder<> &Builder) {
return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
+ getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,
Builder);
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
index ae7da19e1641e..72a12dc1e7c4c 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
@@ -1,22 +1,40 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:64:64' -S < %s | FileCheck %s --check-prefix=PTR64
+; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:32:32' -S < %s | FileCheck %s --check-prefix=PTR32
define <9 x double> @strided_load_3x3(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_3x3(
-; CHECK-NEXT: entry:
-; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
-; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT: [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
-; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START5]]
-; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP6]], align 8
-; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD4]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD8]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT: ret <9 x double> [[TMP2]]
+; PTR64-LABEL: @strided_load_3x3(
+; PTR64-NEXT: entry:
+; PTR64-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT: [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
+; PTR64-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
+; PTR64-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
+; PTR64-NEXT: [[VEC_START4:%.*]] = mul i64 2, [[STRIDE]]
+; PTR64-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START4]]
+; PTR64-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; PTR64-NEXT: [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD3]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; PTR64-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD6]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; PTR64-NEXT: [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; PTR64-NEXT: ret <9 x double> [[TMP2]]
+;
+; PTR32-LABEL: @strided_load_3x3(
+; PTR32-NEXT: entry:
+; PTR32-NEXT: [[STRIDE:%.*]] = trunc i64 [[STRIDE1:%.*]] to i32
+; PTR32-NEXT: [[VEC_START:%.*]] = mul i32 0, [[STRIDE]]
+; PTR32-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; PTR32-NEXT: [[VEC_START1:%.*]] = mul i32 1, [[STRIDE]]
+; PTR32-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i32 [[VEC_START1]]
+; PTR32-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
+; PTR32-NEXT: [[VEC_START4:%.*]] = mul i32 2, [[STRIDE]]
+; PTR32-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN]], i32 [[VEC_START4]]
+; PTR32-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; PTR32-NEXT: [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD3]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; PTR32-NEXT: [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD6]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; PTR32-NEXT: [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; PTR32-NEXT: ret <9 x double> [[TMP2]]
;
entry:
%load = call <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr %in, i64 %stride, i1 false, i32 3, i32 3)
@@ -26,12 +44,20 @@ entry:
declare <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr, i64, i1, i32, i32)
define <9 x double> @strided_load_9x1(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_9x1(
-; CHECK-NEXT: entry:
-; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: ret <9 x double> [[COL_LOAD]]
+; PTR64-LABEL: @strided_load_9x1(
+; PTR64-NEXT: entry:
+; PTR64-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT: [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT: ret <9 x double> [[COL_LOAD]]
+;
+; PTR32-LABEL: @strided_load_9x1(
+; PTR32-NEXT: entry:
+; PTR32-NEXT: [[STRIDE_TRUNC:%.*]] = trunc i64 [[STRIDE:%.*]] to i32
+; PTR32-NEXT: [[VEC_START:%.*]] = mul i32 0, [[STRIDE_TRUNC]]
+; PTR32-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT: [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
+; PTR32-NEXT: ret <9 x double> [[COL_LOAD]]
;
entry:
%load = call <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr %in, i64 %stride, i1 false, i32 9, i32 1)
@@ -41,16 +67,28 @@ entry:
declare <8 x double> @llvm.matrix.column.major.load.v8f64.i64(ptr, i64, i1, i32, i32)
define <8 x double> @strided_load_4x2(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_4x2(
-; CHECK-NEXT: entry:
-; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
-; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <4 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <4 x double> [[COL_LOAD]], <4 x double> [[COL_LOAD4]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: ret <8 x double> [[TMP0]]
+; PTR64-LABEL: @strided_load_4x2(
+; PTR64-NEXT: entry:
+; PTR64-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT: [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT: [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
+; PTR64-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
+; PTR64-NEXT: [[COL_LOAD3:%.*]] = load <4 x double>, ptr [[VEC_GEP2]], align 8
+; PTR64-NEXT: [[TMP0:%.*]] = shufflevector <4 x double> [[COL_LOAD]], <4 x double> [[COL_LOAD3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; PTR64-NEXT: ret <8 x double> [[TMP0]]
+;
+; PTR32-LABEL: @strided_load_4x2(
+; PTR32-NEXT: entry:
+; PTR32-NEXT: [[STRIDE_TRUNC:%.*]] = trunc i64 [[STRIDE:%.*]] to i32
+; PTR32-NEXT: [[VEC_START:%.*]] = mul i32 0, [[STRIDE_TRUNC]]
+; PTR32-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT: [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align...
[truncated]
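The "slightly refactors how shape information is propagated" part of the description boils down to two small helpers. The following is a condensed, hedged sketch of what the diff above adds (the remaining opcode cases of `isShapePreserving` are elided here and stood in for by `return false`; see the full patch for the real list):

```cpp
// Condensed sketch of the shape-propagation refactor from the diff above,
// not a drop-in copy of the pass code.
#include "llvm/ADT/STLExtras.h"   // drop_begin, iterator_range
#include "llvm/IR/Instructions.h" // Instruction, SelectInst, Use, Value

using namespace llvm;

/// Returns true if \p V is an instruction whose result has the same shape as
/// its shaped operands (or if \p V is a non-instruction value). Unlike the
/// old isUniformShape, this also covers `select`, so callers no longer need a
/// separate isa<SelectInst> check.
static bool isShapePreserving(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;
  if (isa<SelectInst>(I))
    return true;
  if (I->isBinaryOp())
    return true;
  return false; // ... further opcode checks as in the patch ...
}

/// The operands of \p I that carry matrix shape. For `select`, the i1
/// predicate is not a matrix, so it is skipped.
static iterator_range<Use *> getShapedOperands(Instruction *I) {
  auto Ops = I->operands();
  return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
}
```

This is what lets both forward and backward shape propagation skip the predicate of a `select` without special-casing it at every call site.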
Thanks for the update. Could you also add a description for the test, explaining the rationale for the change?
I've added a small explanation of the rationale for this change.
LGTM, thanks
To properly support the matrix intrinsics on, e.g., 32-bit platforms (without the need to emit `libc` calls), the `LowerMatrixIntrinsics` pass should generate code that performs strided index calculations using the same pointer bit-width as the matrix pointers, as determined by the data layout. This patch updates the `LowerMatrixIntrinsics` transform to make this the case.

CC @fhahn
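The mechanics of the change can be summarized by the three helpers the patch adds. The sketch below is a minimal, self-contained restatement rather than the pass code itself: in the pass these are members of `LowerMatrixIntrinsics`, so `DL` and the `IRBuilder` are not passed explicitly as they are here.

```cpp
// Minimal sketch of the index-type helpers added by this patch. In the pass
// they are members of LowerMatrixIntrinsics; DL and the IRBuilder are passed
// explicitly here so the snippet is self-contained.
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include <cassert>

using namespace llvm;

/// The integer type the data layout prescribes for indexing through \p Ptr,
/// e.g. i32 under "p:32:32" and i64 under "p:64:64".
static IntegerType *getIndexType(const DataLayout &DL, Value *Ptr) {
  return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
}

/// Materialize constant strides and tile offsets directly in the index type
/// instead of a hard-coded i64.
static Value *getIndex(const DataLayout &DL, Value *Ptr, uint64_t V) {
  return ConstantInt::get(getIndexType(DL, Ptr), V);
}

/// Dynamic strides arrive with whatever integer width the intrinsic call
/// used; zero-extend or truncate them to the index type before they feed the
/// strided GEPs. This is where the `.trunc` values in the PTR32 test checks
/// come from.
static Value *truncateToIndexType(const DataLayout &DL, Value *Ptr, Value *V,
                                  IRBuilder<> &Builder) {
  assert(isa<IntegerType>(V->getType()) && "stride must be an integer");
  return Builder.CreateZExtOrTrunc(V, getIndexType(DL, Ptr),
                                   V->getName() + ".trunc");
}
```

With these in place, both constant and dynamic strides end up in the pointer-index width of the target, which is why the PTR32 checks in the updated tests show i32 index arithmetic and a trunc of the incoming i64 stride, while the PTR64 checks are unchanged.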