
Conversation

@cofibrant (Contributor) commented Oct 9, 2025

To properly support the matrix intrinsics on, e.g., 32-bit platforms (without needing to emit libc calls), the LowerMatrixIntrinsics pass should generate code that performs strided index calculations using the same pointer bit-width as the matrix pointers, as determined by the data layout. This patch updates the LowerMatrixIntrinsics transform accordingly.
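For context, on a target whose data layout declares 32-bit pointers, the lowered IR now narrows the i64 stride to the pointer index type before computing per-column offsets, rather than doing all of the offset arithmetic in i64. A representative fragment for one column of a strided double load, mirroring the new PTR32 check lines in the tests below (value names are illustrative):

```llvm
; With a 32-bit index type (e.g. -data-layout='p:32:32'), the i64 stride
; argument is truncated once, and the per-column offset math is done in i32.
%stride.trunc = trunc i64 %stride to i32
%vec.start = mul i32 1, %stride.trunc
%vec.gep = getelementptr double, ptr %in, i32 %vec.start
%col.load = load <3 x double>, ptr %vec.gep, align 8
```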

CC @fhahn

github-actions bot commented Oct 9, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "pinging" the PR with a comment that says "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot (Member) commented Oct 9, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Nathan Corbyn (cofibrant)

Changes

I've also included a commit that slightly refactors how shape information is propagated.

CC @fhahn


Patch is 70.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162646.diff

7 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+63-42)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll (+92-42)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-float.ll (+70-32)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-i32.ll (+70-32)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-store-double.ll (+94-44)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-store-float.ll (+72-34)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/strided-store-i32.ll (+72-34)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 56e0569831e83..408372efdb93b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -44,7 +44,7 @@
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/MatrixUtils.h"
@@ -241,11 +241,16 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
 
 } // namespace
 
-static bool isUniformShape(Value *V) {
+/// Returns true if \p V is an instruction whose result is of the same shape
+/// as its operands (or if \p V is a non-instruction value).
+static bool isShapePreserving(Value *V) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I)
     return true;
 
+  if (isa<SelectInst>(I))
+    return true;
+
   if (I->isBinaryOp())
     return true;
 
@@ -296,6 +301,13 @@ static bool isUniformShape(Value *V) {
   }
 }
 
+static iterator_range<Use *> getShapedOperands(Instruction *I) {
+  auto Ops = I->operands();
+  // Ignore shape information for the predicate operand of a `select`
+  // instruction
+  return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+}
+
 /// Return the ShapeInfo for the result of \p I, it it can be determined.
 static std::optional<ShapeInfo>
 computeShapeInfoForInst(Instruction *I,
@@ -325,9 +337,8 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
-  if (isUniformShape(I) || isa<SelectInst>(I)) {
-    auto Ops = I->operands();
-    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+  if (isShapePreserving(I)) {
+    auto ShapedOps = getShapedOperands(I);
     // Find the first operand that has a known shape and use that.
     for (auto &Op : ShapedOps) {
       auto OpShape = ShapeMap.find(Op.get());
@@ -633,18 +644,16 @@ class LowerMatrixIntrinsics {
       if (Found != Inst2ColumnMatrix.end()) {
         // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
         // that embedInVector created.
-        LLVM_DEBUG(dbgs() << "matrix reshape from " << Found->second.shape()
-                          << " to " << SI << " using at least "
-                          << SplitVecs.size() << " shuffles on behalf of:\n"
-                          << *Inst << '\n');
+        LDBG() << "matrix reshape from " << Found->second.shape() << " to "
+               << SI << " using at least " << SplitVecs.size()
+               << " shuffles on behalf of:\n"
+               << *Inst << '\n';
         ReshapedMatrices++;
       } else if (!ShapeMap.contains(MatrixVal)) {
-        LLVM_DEBUG(
-            dbgs()
-            << "splitting a " << SI << " matrix with " << SplitVecs.size()
-            << " shuffles beacuse we do not have a shape-aware lowering for "
-               "its def:\n"
-            << *Inst << '\n');
+        LDBG() << "splitting a " << SI << " matrix with " << SplitVecs.size()
+               << " shuffles beacuse we do not have a shape-aware lowering for "
+                  "its def:\n"
+               << *Inst << '\n';
         (void)Inst;
         SplitMatrices++;
       } else {
@@ -675,15 +684,14 @@ class LowerMatrixIntrinsics {
             "Matrix shape verification failed, compilation aborted!");
       }
 
-      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
-                        << SIter->second.NumRows << " "
-                        << SIter->second.NumColumns << " for " << *V << "\n");
+      LDBG() << "  not overriding existing shape: " << SIter->second.NumRows
+             << " " << SIter->second.NumColumns << " for " << *V << "\n";
       return false;
     }
 
     ShapeMap.insert({V, Shape});
-    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
-                      << " for " << *V << "\n");
+    LDBG() << "  " << Shape.NumRows << " x " << Shape.NumColumns << " for "
+           << *V << "\n";
     return true;
   }
 
@@ -703,10 +711,9 @@ class LowerMatrixIntrinsics {
       case Intrinsic::matrix_column_major_store:
         return true;
       default:
-        return isUniformShape(II);
+        break;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
-           isa<SelectInst>(V);
+    return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -719,7 +726,7 @@ class LowerMatrixIntrinsics {
     // Pop an element for which we guaranteed to have at least one of the
     // operand shapes.  Add the shape for this and then add users to the work
     // list.
-    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
+    LDBG() << "Forward-propagate shapes:\n";
     while (!WorkList.empty()) {
       Instruction *Inst = WorkList.pop_back_val();
 
@@ -754,7 +761,7 @@ class LowerMatrixIntrinsics {
     // Pop an element with known shape.  Traverse the operands, if their shape
     // derives from the result shape and is unknown, add it and add them to the
     // worklist.
-    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
+    LDBG() << "Backward-propagate shapes:\n";
     while (!WorkList.empty()) {
       Value *V = WorkList.pop_back_val();
 
@@ -778,7 +785,8 @@ class LowerMatrixIntrinsics {
 
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(MatrixA), m_Value(M), m_Value(N)))) {
-        // Flip dimensions.
+        // We're told MatrixA is M x N so propagate this information directly.
+        // Compare \f computeSahpeInfoForInst where the dimensions are flipped.
         if (setShapeInfo(MatrixA, {M, N}))
           pushInstruction(MatrixA, WorkList);
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
@@ -793,10 +801,9 @@ class LowerMatrixIntrinsics {
       } else if (isa<StoreInst>(V)) {
         // Nothing to do.  We forward-propagated to this so we would just
         // backward propagate to an instruction with an already known shape.
-      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
-        auto Ops = cast<Instruction>(V)->operands();
-        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
-        // Propagate to all operands.
+      } else if (isShapePreserving(V)) {
+        auto ShapedOps = getShapedOperands(cast<Instruction>(V));
+        // Propagate to all shaped operands.
         ShapeInfo Shape = ShapeMap[V];
         for (Use &U : ShapedOps) {
           if (setShapeInfo(U.get(), Shape))
@@ -1295,6 +1302,19 @@ class LowerMatrixIntrinsics {
     return commonAlignment(InitialAlign, ElementSizeInBits / 8);
   }
 
+  IntegerType *getIndexType(Value *Ptr) const {
+    return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
+  }
+
+  Value *getIndex(Value *Ptr, uint64_t V) const {
+    return ConstantInt::get(getIndexType(Ptr), V);
+  }
+
+  Value *truncateToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const {
+    assert(isa<IntegerType>(V->getType()));
+    return Builder.CreateZExtOrTrunc(V, getIndexType(Ptr), V->getName() + ".trunc");
+  }
+
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// vectors.
   MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
@@ -1304,6 +1324,7 @@ class LowerMatrixIntrinsics {
     Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
     Value *EltPtr = Ptr;
     MatrixTy Result;
+    Stride = truncateToIndexType(Ptr, Stride, Builder);
     for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
       Value *GEP = computeVectorAddr(
           EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
@@ -1325,14 +1346,14 @@ class LowerMatrixIntrinsics {
                       ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
-        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+        Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
 
     Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
     auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                    ResultShape.NumColumns);
 
     return loadMatrix(TileTy, TileStart, Align,
-                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+                      getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
                       ResultShape, Builder);
   }
 
@@ -1363,14 +1384,15 @@ class LowerMatrixIntrinsics {
                    MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                    Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
-        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+        Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
 
     Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
     auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                    StoreVal.getNumColumns());
 
     storeMatrix(TileTy, StoreVal, TileStart, MAlign,
-                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
+                getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
+                Builder);
   }
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
@@ -1380,6 +1402,7 @@ class LowerMatrixIntrinsics {
                        IRBuilder<> &Builder) {
     auto *VType = cast<FixedVectorType>(Ty);
     Value *EltPtr = Ptr;
+    Stride = truncateToIndexType(Ptr, Stride, Builder);
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(
           EltPtr,
@@ -2011,18 +2034,17 @@ class LowerMatrixIntrinsics {
             const unsigned TileM = std::min(M - K, unsigned(TileSize));
             MatrixTy A =
                 loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
-                           LShape, Builder.getInt64(I), Builder.getInt64(K),
+                           LShape, getIndex(APtr, I), getIndex(APtr, K),
                            {TileR, TileM}, EltType, Builder);
             MatrixTy B =
                 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
-                           RShape, Builder.getInt64(K), Builder.getInt64(J),
+                           RShape, getIndex(BPtr, K), getIndex(BPtr, J),
                            {TileM, TileC}, EltType, Builder);
             emitMatrixMultiply(Res, A, B, Builder, true, false,
                                getFastMathFlags(MatMul));
           }
           storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
-                      Builder.getInt64(I), Builder.getInt64(J), EltType,
-                      Builder);
+                      getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);
         }
     }
 
@@ -2254,15 +2276,14 @@ class LowerMatrixIntrinsics {
   /// Lower load instructions.
   MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
                      IRBuilder<> &Builder) {
-    return LowerLoad(Inst, Ptr, Inst->getAlign(),
-                     Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
-                     Builder);
+    return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),
+                     Inst->isVolatile(), SI, Builder);
   }
 
   MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
                       Value *Ptr, IRBuilder<> &Builder) {
     return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
-                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
+                      getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,
                       Builder);
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
index ae7da19e1641e..72a12dc1e7c4c 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load-double.ll
@@ -1,22 +1,40 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:64:64' -S < %s | FileCheck %s --check-prefix=PTR64
+; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:32:32' -S < %s | FileCheck %s --check-prefix=PTR32
 
 define <9 x double> @strided_load_3x3(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_3x3(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT:    [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
-; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START5]]
-; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP6]], align 8
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD4]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD8]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT:    ret <9 x double> [[TMP2]]
+; PTR64-LABEL: @strided_load_3x3(
+; PTR64-NEXT:  entry:
+; PTR64-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
+; PTR64-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
+; PTR64-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
+; PTR64-NEXT:    [[VEC_START4:%.*]] = mul i64 2, [[STRIDE]]
+; PTR64-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START4]]
+; PTR64-NEXT:    [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; PTR64-NEXT:    [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD3]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; PTR64-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD6]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; PTR64-NEXT:    [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; PTR64-NEXT:    ret <9 x double> [[TMP2]]
+;
+; PTR32-LABEL: @strided_load_3x3(
+; PTR32-NEXT:  entry:
+; PTR32-NEXT:    [[STRIDE:%.*]] = trunc i64 [[STRIDE1:%.*]] to i32
+; PTR32-NEXT:    [[VEC_START:%.*]] = mul i32 0, [[STRIDE]]
+; PTR32-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT:    [[COL_LOAD:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
+; PTR32-NEXT:    [[VEC_START1:%.*]] = mul i32 1, [[STRIDE]]
+; PTR32-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i32 [[VEC_START1]]
+; PTR32-NEXT:    [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 8
+; PTR32-NEXT:    [[VEC_START4:%.*]] = mul i32 2, [[STRIDE]]
+; PTR32-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN]], i32 [[VEC_START4]]
+; PTR32-NEXT:    [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
+; PTR32-NEXT:    [[TMP0:%.*]] = shufflevector <3 x double> [[COL_LOAD]], <3 x double> [[COL_LOAD3]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; PTR32-NEXT:    [[TMP1:%.*]] = shufflevector <3 x double> [[COL_LOAD6]], <3 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; PTR32-NEXT:    [[TMP2:%.*]] = shufflevector <6 x double> [[TMP0]], <6 x double> [[TMP1]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; PTR32-NEXT:    ret <9 x double> [[TMP2]]
 ;
 entry:
   %load = call <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr %in, i64 %stride, i1 false, i32 3, i32 3)
@@ -26,12 +44,20 @@ entry:
 declare <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr, i64, i1, i32, i32)
 
 define <9 x double> @strided_load_9x1(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_9x1(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    ret <9 x double> [[COL_LOAD]]
+; PTR64-LABEL: @strided_load_9x1(
+; PTR64-NEXT:  entry:
+; PTR64-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT:    [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT:    ret <9 x double> [[COL_LOAD]]
+;
+; PTR32-LABEL: @strided_load_9x1(
+; PTR32-NEXT:  entry:
+; PTR32-NEXT:    [[STRIDE_TRUNC:%.*]] = trunc i64 [[STRIDE:%.*]] to i32
+; PTR32-NEXT:    [[VEC_START:%.*]] = mul i32 0, [[STRIDE_TRUNC]]
+; PTR32-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT:    [[COL_LOAD:%.*]] = load <9 x double>, ptr [[VEC_GEP]], align 8
+; PTR32-NEXT:    ret <9 x double> [[COL_LOAD]]
 ;
 entry:
   %load = call <9 x double> @llvm.matrix.column.major.load.v9f64.i64(ptr %in, i64 %stride, i1 false, i32 9, i32 1)
@@ -41,16 +67,28 @@ entry:
 declare <8 x double> @llvm.matrix.column.major.load.v8f64.i64(ptr, i64, i1, i32, i32)
 
 define <8 x double> @strided_load_4x2(ptr %in, i64 %stride) {
-; CHECK-LABEL: @strided_load_4x2(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <4 x double>, ptr [[VEC_GEP2]], align 8
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <4 x double> [[COL_LOAD]], <4 x double> [[COL_LOAD4]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    ret <8 x double> [[TMP0]]
+; PTR64-LABEL: @strided_load_4x2(
+; PTR64-NEXT:  entry:
+; PTR64-NEXT:    [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; PTR64-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i64 [[VEC_START]]
+; PTR64-NEXT:    [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align 8
+; PTR64-NEXT:    [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
+; PTR64-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN]], i64 [[VEC_START1]]
+; PTR64-NEXT:    [[COL_LOAD3:%.*]] = load <4 x double>, ptr [[VEC_GEP2]], align 8
+; PTR64-NEXT:    [[TMP0:%.*]] = shufflevector <4 x double> [[COL_LOAD]], <4 x double> [[COL_LOAD3]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; PTR64-NEXT:    ret <8 x double> [[TMP0]]
+;
+; PTR32-LABEL: @strided_load_4x2(
+; PTR32-NEXT:  entry:
+; PTR32-NEXT:    [[STRIDE_TRUNC:%.*]] = trunc i64 [[STRIDE:%.*]] to i32
+; PTR32-NEXT:    [[VEC_START:%.*]] = mul i32 0, [[STRIDE_TRUNC]]
+; PTR32-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN:%.*]], i32 [[VEC_START]]
+; PTR32-NEXT:    [[COL_LOAD:%.*]] = load <4 x double>, ptr [[VEC_GEP]], align...
[truncated]

@cofibrant cofibrant changed the title [Matrix] Use data layout index type for lower matrix intrinsics [Matrix] Use data layout index type for lowering matrix intrinsics Oct 9, 2025
@fhahn fhahn requested review from anemet, fhahn and jroelofs and removed request for fhahn October 9, 2025 12:58
@cofibrant cofibrant force-pushed the cofibrant/matrix-intrinsics-index branch 2 times, most recently from 09c6e8d to bcfb878 October 9, 2025 15:38
@llvmbot llvmbot added the llvm:ir label Oct 9, 2025
@fhahn fhahn requested a review from farzonl October 10, 2025 08:48
@fhahn fhahn left a comment

Thanks for the update. Could you also add a description for the test, explaining the rationale for the change?

@jroelofs jroelofs requested a review from lei137 October 10, 2025 16:26
@cofibrant cofibrant force-pushed the cofibrant/matrix-intrinsics-index branch from bcfb878 to a845260 October 13, 2025 08:41
@cofibrant (Contributor, Author)

I've added a small explanation of the rationale for this change in data-layout.ll and a pointer to that explanation in data-layout-fused-multiply.ll.
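For reference, the updated strided-load/store tests exercise both configurations by overriding the data layout on the opt command line; a sketch of those RUN lines, based on strided-load-double.ll in this PR (the exact contents of data-layout.ll are not shown in the truncated diff):

```llvm
; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:64:64' -S < %s | FileCheck %s --check-prefix=PTR64
; RUN: opt -passes='lower-matrix-intrinsics' -data-layout='p:32:32' -S < %s | FileCheck %s --check-prefix=PTR32
```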

@cofibrant cofibrant force-pushed the cofibrant/matrix-intrinsics-index branch from a845260 to 2a7edd1 October 13, 2025 09:37
@fhahn fhahn left a comment

LGTM, thanks

@fhahn fhahn merged commit 625aa09 into llvm:main Oct 13, 2025
11 of 15 checks passed
@github-actions

@cofibrant Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Oct 13, 2025
…trinsics (#162646)

PR: llvm/llvm-project#162646
akadutta pushed a commit to akadutta/llvm-project that referenced this pull request Oct 14, 2025
…lvm#162646)

PR: llvm#162646
akadutta pushed a commit to akadutta/llvm-project that referenced this pull request Oct 14, 2025
cofibrant added a commit to cofibrant/llvm-project that referenced this pull request Oct 16, 2025
…lvm#162646)

PR: llvm#162646
(cherry picked from commit 625aa09)
fhahn added a commit to swiftlang/llvm-project that referenced this pull request Oct 18, 2025
…-intrinsics-index

🍒 [Matrix] Use data layout index type for lowering matrix intrinsics (llvm#162646)


rdar://97204460
@cofibrant cofibrant deleted the cofibrant/matrix-intrinsics-index branch October 22, 2025 14:47