Skip to content

Conversation

@jroelofs
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144369.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+21-27)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 1e37f40fa9d52..ece0bb56fff01 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1146,24 +1146,24 @@ class LowerMatrixIntrinsics {
       Value *Op1;
       Value *Op2;
       MatrixTy Result;
+      IRBuilder<> Builder(Inst);
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
-        Result = VisitBinaryOperator(BinOp, SI);
+        Result = VisitBinaryOperator(BinOp, SI, Builder);
       else if (auto *Cast = dyn_cast<CastInst>(Inst))
-        Result = VisitCastInstruction(Cast, SI);
+        Result = VisitCastInstruction(Cast, SI, Builder);
       else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
-        Result = VisitUnaryOperator(UnOp, SI);
+        Result = VisitUnaryOperator(UnOp, SI, Builder);
       else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
-        Result = VisitIntrinsicInst(Intr, SI);
+        Result = VisitIntrinsicInst(Intr, SI, Builder);
       else if (auto *Select = dyn_cast<SelectInst>(Inst))
-        Result = VisitSelectInst(Select, SI);
+        Result = VisitSelectInst(Select, SI, Builder);
       else if (match(Inst, m_Load(m_Value(Op1))))
-        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
+        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
-        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
+        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
       else
         continue;
 
-      IRBuilder<> Builder(Inst);
       finalizeLowering(Inst, Result, Builder);
       Changed = true;
     }
@@ -1204,7 +1204,8 @@ class LowerMatrixIntrinsics {
   }
 
   /// Replace intrinsic calls.
-  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
+                              IRBuilder<> &Builder) {
     assert(Inst->getCalledFunction() &&
            Inst->getCalledFunction()->isIntrinsic());
 
@@ -1219,7 +1220,6 @@ class LowerMatrixIntrinsics {
       return LowerColumnMajorStore(Inst);
     case Intrinsic::abs:
     case Intrinsic::fabs: {
-      IRBuilder<> Builder(Inst);
       MatrixTy Result;
       MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
       Builder.setFastMathFlags(getFastMathFlags(Inst));
@@ -1298,7 +1298,6 @@ class LowerMatrixIntrinsics {
                       ShapeInfo MatrixShape, Value *I, Value *J,
                       ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
-
     Value *Offset = Builder.CreateAdd(
         Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
 
@@ -2228,26 +2227,24 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower load instructions.
-  MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
-    IRBuilder<> Builder(Inst);
+  MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
+                     IRBuilder<> &Builder) {
     return LowerLoad(Inst, Ptr, Inst->getAlign(),
                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
   MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
-                      Value *Ptr) {
-    IRBuilder<> Builder(Inst);
+                      Value *Ptr, IRBuilder<> &Builder) {
     return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
                       Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
   /// Lower binary operators.
-  MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
+                               IRBuilder<> &Builder) {
     Value *Lhs = Inst->getOperand(0);
     Value *Rhs = Inst->getOperand(1);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy A = getMatrix(Lhs, SI, Builder);
     MatrixTy B = getMatrix(Rhs, SI, Builder);
@@ -2265,11 +2262,10 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower unary operators.
-  MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
+                              IRBuilder<> &Builder) {
     Value *Op = Inst->getOperand(0);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy M = getMatrix(Op, SI, Builder);
 
@@ -2293,11 +2289,10 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower cast instructions.
-  MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
+  MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
+                                IRBuilder<> &Builder) {
     Value *Op = Inst->getOperand(0);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy M = getMatrix(Op, Shape, Builder);
 
@@ -2315,13 +2310,12 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower selects.
-  MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
+  MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,
+                           IRBuilder<> &Builder) {
     Value *Cond = Inst->getOperand(0);
     Value *OpA = Inst->getOperand(1);
     Value *OpB = Inst->getOperand(2);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy A = getMatrix(OpA, Shape, Builder);
     MatrixTy B = getMatrix(OpB, Shape, Builder);

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@jroelofs jroelofs merged commit 8ed43c4 into llvm:main Jun 16, 2025
9 checks passed
@jroelofs jroelofs deleted the jroelofs/lower-matrix-builder branch June 16, 2025 18:38
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants