Skip to content

Conversation

@jroelofs
Copy link
Contributor

When possible, this avoids a bunch of shuffles in & out of the flattened
layout.

@llvmbot
Copy link
Member

llvmbot commented May 29, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

When possible, this avoids a bunch of shuffles in & out of the flattened
layout.


Patch is 21.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142055.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+185-4)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll (+41)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll (+200)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..4f997f2133527 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,16 +30,19 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -1101,6 +1104,7 @@ class LowerMatrixIntrinsics {
     if (!PoisonedInsts.empty()) {
       // If we didn't remove all poisoned instructions, it's a hard error.
       dbgs() << "Poisoned but present instructions:\n";
+      Func.dump();
       for (auto *I : PoisonedInsts)
         dbgs() << *I << "\n";
       llvm_unreachable("Poisoned but instruction not removed");
@@ -1337,6 +1341,155 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
+  bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+    Value *Op0 = Inst->getOperand(0);
+    auto *VTy = cast<VectorType>(Op0->getType());
+
+    if (VTy->getElementCount().getKnownMinValue() < Index) {
+      Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+      Inst->eraseFromParent();
+      return true;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op0);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+
+    const MatrixTy &M = I->second;
+
+    IRBuilder<> Builder(Inst);
+    Inst->setOperand(0, M.getVector(Index / M.getStride()));
+    Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+    return true;
+  }
+
+  bool VisitReduce(IntrinsicInst *Inst) {
+    FastMathFlags FMF = getFastMathFlags(Inst);
+
+    if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+      return false;
+
+    Value *Start = nullptr;
+    Value *Op = nullptr;
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd:
+    case Intrinsic::vector_reduce_fmul:
+      Start = Inst->getOperand(0);
+      Op = Inst->getOperand(1);
+      break;
+    case Intrinsic::vector_reduce_fmax:
+    case Intrinsic::vector_reduce_fmaximum:
+    case Intrinsic::vector_reduce_fmin:
+    case Intrinsic::vector_reduce_fminimum:
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_mul:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_xor:
+      Op = Inst->getOperand(0);
+      break;
+    default:
+      llvm_unreachable("unexpected intrinsic");
+    }
+
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd: {
+      if (!match(Start, m_AnyZeroFP()))
+        return false;
+    } break;
+    case Intrinsic::vector_reduce_fmul: {
+      if (!match(Start, m_FPOne()))
+        return false;
+    } break;
+    default:
+      break;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+
+    IRBuilder<> Builder(Inst);
+
+    const MatrixTy &M = I->second;
+
+    auto CreateVReduce = [&](Value *LHS, Value *RHS) {
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_add:
+        return Builder.CreateAdd(LHS, RHS);
+      case Intrinsic::vector_reduce_and:
+        return Builder.CreateAnd(LHS, RHS);
+      case Intrinsic::vector_reduce_fadd:
+        return Builder.CreateFAdd(LHS, RHS);
+      case Intrinsic::vector_reduce_fmax:
+        return Builder.CreateMaximum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fmaximum:
+        return Builder.CreateMaximumNum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fmin:
+        return Builder.CreateMinimum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fminimum:
+        return Builder.CreateMinimumNum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fmul:
+        return Builder.CreateFMul(LHS, RHS);
+      case Intrinsic::vector_reduce_mul:
+        return Builder.CreateMul(LHS, RHS);
+      case Intrinsic::vector_reduce_or:
+        return Builder.CreateOr(LHS, RHS);
+      case Intrinsic::vector_reduce_xor:
+        return Builder.CreateXor(LHS, RHS);
+      default:
+        llvm_unreachable("unexpected intrinsic");
+      }
+    };
+
+    Value *ResultV;
+    if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
+        Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
+      ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()), Start);
+      for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
+        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+    } else {
+      ResultV = M.getVector(0);
+      for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
+        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+    }
+
+    auto CreateHReduce = [&](Value *V) {
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_add:
+        return Builder.CreateAddReduce(V);
+      case Intrinsic::vector_reduce_and:
+        return Builder.CreateAndReduce(V);
+      case Intrinsic::vector_reduce_fadd:
+        return Builder.CreateFAddReduce(Start, V);
+      case Intrinsic::vector_reduce_fmax:
+        return Builder.CreateFPMaxReduce(V);
+      case Intrinsic::vector_reduce_fmaximum:
+        return Builder.CreateFPMaximumReduce(V);
+      case Intrinsic::vector_reduce_fmin:
+        return Builder.CreateFPMinReduce(V);
+      case Intrinsic::vector_reduce_fminimum:
+        return Builder.CreateFPMinimumReduce(V);
+      case Intrinsic::vector_reduce_fmul:
+        return Builder.CreateFMulReduce(Start, V);
+      case Intrinsic::vector_reduce_mul:
+        return Builder.CreateMulReduce(V);
+      case Intrinsic::vector_reduce_or:
+        return Builder.CreateOrReduce(V);
+      case Intrinsic::vector_reduce_xor:
+        return Builder.CreateXorReduce(V);
+      default:
+        llvm_unreachable("unexpected intrinsic");
+      }
+    };
+
+    Value *Result = CreateHReduce(ResultV);
+    Inst->replaceAllUsesWith(Result);
+    Result->takeName(Inst);
+    Inst->eraseFromParent();
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1504,39 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      Value *Op1;
+      uint64_t Index;
+      if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+        if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+          continue;
+
+      if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
+        switch (Intr->getIntrinsicID()) {
+        case Intrinsic::vector_reduce_add:
+        case Intrinsic::vector_reduce_and:
+        case Intrinsic::vector_reduce_fadd:
+        case Intrinsic::vector_reduce_fmax:
+        case Intrinsic::vector_reduce_fmaximum:
+        case Intrinsic::vector_reduce_fmin:
+        case Intrinsic::vector_reduce_fminimum:
+        case Intrinsic::vector_reduce_fmul:
+        case Intrinsic::vector_reduce_mul:
+        case Intrinsic::vector_reduce_or:
+        case Intrinsic::vector_reduce_xor:
+          if (VisitReduce(Intr))
+            continue;
+          break;
+        default:
+          break;
+        }
       }
+
+      if (!Flattened)
+        Flattened = Matrix.embedInVector(Builder);
+      U.set(Flattened);
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..db5444ca036ae
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,41 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+  %extract = extractelement <4 x float> %inv, i32 3
+  ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    ret float poison
+;
+  %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+  %extract = extractelement <4 x float> %inv, i32 5
+  ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+  %extract = extractelement <4 x float> %inv, i32 %idx
+  ret float %extract
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
new file mode 100644
index 0000000000000..41f65e01fec79
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define i32 @reduce_add(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_and(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_and(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.and(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_or(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_or(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = or <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.or(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_mul(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_mul(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.mul(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_xor(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_xor(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.xor(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define float @reduce_fadd(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_weirdstart(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v8f32(float 1.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fadd(float 1., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmul_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmul(float 1., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_weirdstart(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmul(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximumnu...
[truncated]

@jroelofs jroelofs force-pushed the jroelofs/lower-matrix-reduce branch from e1368d5 to 897768b Compare May 29, 2025 22:52
@github-actions
Copy link

github-actions bot commented May 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@jroelofs jroelofs marked this pull request as draft May 29, 2025 22:53
@jroelofs jroelofs force-pushed the jroelofs/lower-matrix-reduce branch from 897768b to 51e0487 Compare May 29, 2025 22:55
@jroelofs jroelofs marked this pull request as ready for review June 10, 2025 19:15
@jroelofs
Copy link
Contributor Author

ping

@jroelofs jroelofs requested a review from fhahn June 18, 2025 17:44
Comment on lines +46 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be driven by a cost function? The new lowering seems strictly worse?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rebased the patch stack to pseudo pre-land the test, so the diff is more obvious.

aff9068#diff-6a09d32782efca5b9899b5ed357c5befd70d65f461a2ae64284f2504975f8948

Hm, yeah, that is bad: https://llvm.godbolt.org/z/7EbfGer3E

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(It's not completely now, we already get the similarly bad bevhavior for other ops, but with the reductions it's probably even worth, so it might be worth tackling before going further down that road)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, wait, it's better, at least on cyclone (the only thing we have a scheduling model for at the moment)
old: https://llvm.godbolt.org/z/foash7aMq
new: https://llvm.godbolt.org/z/99aTd41Mr

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any ideas on what the cost function would be, if we had one?

@jroelofs jroelofs force-pushed the jroelofs/lower-matrix-reduce branch from e1cb5a1 to aff9068 Compare June 23, 2025 18:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants