From 4040d3fc777ff8d5b212e77fac604f60d997475a Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Wed, 28 May 2025 16:16:50 -0700 Subject: [PATCH 1/5] [Matrix] Propagate shape information through Select insts --- .../Scalar/LowerMatrixIntrinsics.cpp | 49 ++++++++++++- .../LowerMatrixIntrinsics/select.ll | 68 +++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/select.ll diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 756a72e6d97bc..6c364f057481a 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -269,6 +269,15 @@ computeShapeInfoForInst(Instruction *I, return OpShape->second; } + if (isa(I)) { + auto OpShape = ShapeMap.find(I->getOperand(1)); + if (OpShape != ShapeMap.end()) + return OpShape->second; + OpShape = ShapeMap.find(I->getOperand(2)); + if (OpShape != ShapeMap.end()) + return OpShape->second; + } + if (isUniformShape(I)) { // Find the first operand that has a known shape and use that. for (auto &Op : I->operands()) { @@ -623,7 +632,8 @@ class LowerMatrixIntrinsics { default: return false; } - return isUniformShape(V) || isa(V) || isa(V); + return isUniformShape(V) || isa(V) || isa(V) || + isa(V); } /// Propagate the shape information of instructions to their users. @@ -710,6 +720,12 @@ class LowerMatrixIntrinsics { } else if (isa(V)) { // Nothing to do. We forward-propagated to this so we would just // backward propagate to an instruction with an already known shape. + } else if (auto *Select = dyn_cast(V)) { + ShapeInfo Shape = ShapeMap[V]; + if (setShapeInfo(Select->getOperand(1), Shape)) + pushInstruction(Select, WorkList); + if (setShapeInfo(Select->getOperand(2), Shape)) + pushInstruction(Select, WorkList); } else if (isUniformShape(V)) { // Propagate to all operands. ShapeInfo Shape = ShapeMap[V]; @@ -1068,6 +1084,8 @@ class LowerMatrixIntrinsics { Changed |= VisitBinaryOperator(BinOp); if (auto *UnOp = dyn_cast(Inst)) Changed |= VisitUnaryOperator(UnOp); + if (auto *Select = dyn_cast(Inst)) + Changed |= VisitSelectInst(Select); if (match(Inst, m_Load(m_Value(Op1)))) Changed |= VisitLoad(cast(Inst), Op1, Builder); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) @@ -2198,6 +2216,35 @@ class LowerMatrixIntrinsics { return true; } + /// Lower selects, if shape information is available. + bool VisitSelectInst(SelectInst *Inst) { + auto I = ShapeMap.find(Inst); + if (I == ShapeMap.end()) + return false; + + Value *Cond = Inst->getOperand(0); + Value *OpA = Inst->getOperand(1); + Value *OpB = Inst->getOperand(2); + + IRBuilder<> Builder(Inst); + ShapeInfo &Shape = I->second; + + MatrixTy Result; + MatrixTy A = getMatrix(OpA, Shape, Builder); + MatrixTy B = getMatrix(OpB, Shape, Builder); + + for (unsigned I = 0; I < Shape.getNumVectors(); ++I) { + auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I)); + Result.addVector(Sel); + } + + finalizeLowering(Inst, + Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()), + Builder); + return true; + } + /// Helper to linearize a matrix expression tree into a string. Currently /// matrix expressions are linarized by starting at an expression leaf and /// linearizing bottom up. diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll new file mode 100644 index 0000000000000..507b02a04f47f --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll @@ -0,0 +1,68 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define void @select_2x2_bot(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_bot( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]] +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]] +; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 4 +; CHECK-NEXT: ret void +; + %lhsv = load <4 x float>, ptr %lhs + %rhsv = load <4 x float>, ptr %rhs + %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv + call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2) + ret void +} + +define void @select_2x2_lhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_lhs( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]] +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]] +; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8 +; CHECK-NEXT: ret void +; + %lhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %lhs, i64 2, i1 false, i32 2, i32 2) + %rhsv = load <4 x float>, ptr %rhs + %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv + store <4 x float> %op, ptr %out + ret void +} + +define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_rhs( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS1]], i64 2 +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]] +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]] +; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8 +; CHECK-NEXT: ret void +; + %lhsv = load <4 x float>, ptr %lhs + %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2) + %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv + store <4 x float> %op, ptr %out + ret void +} From c0c63f392205e42a1421c475c8d120d49ba9bf1d Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Thu, 5 Jun 2025 13:46:58 -0700 Subject: [PATCH 2/5] select with mismatched shape --- .../Scalar/LowerMatrixIntrinsics.cpp | 14 ++-- .../LowerMatrixIntrinsics/select.ll | 66 +++++++++++++------ 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index ef57faf911e31..7da55a2a9a355 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -271,7 +271,9 @@ computeShapeInfoForInst(Instruction *I, } if (auto *Select = dyn_cast(I)) { - for (Use &Op : Select->getCondition()->getType()->isVectorTy() ? I->operands() : drop_begin(I->operands())) { + Type *CondTy = Select->getCondition()->getType(); + for (Use &Op : CondTy->isVectorTy() ? Select->operands() + : drop_begin(Select->operands())) { auto OpShape = ShapeMap.find(Op); if (OpShape != ShapeMap.end()) return OpShape->second; @@ -719,10 +721,12 @@ class LowerMatrixIntrinsics { // backward propagate to an instruction with an already known shape. } else if (auto *Select = dyn_cast(V)) { ShapeInfo Shape = ShapeMap[V]; - if (setShapeInfo(Select->getOperand(1), Shape)) - pushInstruction(Select, WorkList); - if (setShapeInfo(Select->getOperand(2), Shape)) - pushInstruction(Select, WorkList); + Type *CondTy = Select->getCondition()->getType(); + for (Use &Op : CondTy->isVectorTy() ? Select->operands() + : drop_begin(Select->operands())) { + if (setShapeInfo(Op, Shape)) + pushInstruction(Select, WorkList); + } } else if (isUniformShape(V)) { // Propagate to all operands. ShapeInfo Shape = ShapeMap[V]; diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll index 31c34e24c540d..56dca7bb985d3 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll @@ -67,39 +67,41 @@ define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) { ret void } -define void @select_2x2_vcond(<4 x i1> %cond, ptr %lhs, ptr %rhs, ptr %out) { -; CHECK-LABEL: @select_2x2_vcond( +define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_vcond_shape1( ; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 -; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4 -; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2 -; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4 -; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COND:%.*]], <4 x i1> poison, <2 x i32> -; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <4 x i1> [[COND]], <4 x i1> poison, <2 x i32> -; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]] -; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT5]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]] +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1 +; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2 +; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]] +; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]] ; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16 -; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[OUT]], i64 2 -; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP6]], align 8 +; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8 ; CHECK-NEXT: ret void ; %lhsv = load <4 x float>, ptr %lhs + %condv = load <4 x i1>, ptr %cond %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2) - %op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv + %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv store <4 x float> %op, ptr %out ret void } -define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) { -; CHECK-LABEL: @select_2x2_vcond_shape( +define void @select_2x2_vcond_shape2(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_vcond_shape2( ; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 -; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1 -; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[COND:%.*]], align 1 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[COND]], i64 2 ; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1 -; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS]], align 4 +; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4 ; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS]], i64 2 ; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4 ; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]] @@ -110,9 +112,35 @@ define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) { ; CHECK-NEXT: ret void ; %lhsv = load <4 x float>, ptr %lhs - %cond = call <4 x i1> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2) + %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 2, i1 false, i32 2, i32 2) + %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2) + %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv + store <4 x float> %op, ptr %out + ret void +} + +define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { +; CHECK-LABEL: @select_2x2_vcond_shape3( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1 +; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[RHS]], i64 2 +; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[VEC_GEP4]], align 4 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> +; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD3]] +; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT6]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD5]] +; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16 +; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr float, ptr [[OUT]], i64 2 +; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP7]], align 8 +; CHECK-NEXT: ret void +; + %lhsv = load <4 x float>, ptr %lhs + %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1) %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2) - %op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv + %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv store <4 x float> %op, ptr %out ret void } From 5a4a6a507f57da9fc4d4081385a9f8339929da87 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Thu, 5 Jun 2025 13:48:15 -0700 Subject: [PATCH 3/5] no return value --- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 7da55a2a9a355..deff2908b4902 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1087,7 +1087,7 @@ class LowerMatrixIntrinsics { else if (CallInst *CInst = dyn_cast(Inst)) VisitCallInst(CInst); else if (auto *Select = dyn_cast(Inst)) - Changed |= VisitSelectInst(Select, SI); + VisitSelectInst(Select, SI); else if (match(Inst, m_Load(m_Value(Op1)))) VisitLoad(cast(Inst), SI, Op1, Builder); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) @@ -2200,7 +2200,7 @@ class LowerMatrixIntrinsics { } /// Lower selects. - bool VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) { + void VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) { Value *Cond = Inst->getOperand(0); Value *OpA = Inst->getOperand(1); Value *OpB = Inst->getOperand(2); @@ -2228,7 +2228,6 @@ class LowerMatrixIntrinsics { Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * Result.getNumVectors()), Builder); - return true; } /// Helper to linearize a matrix expression tree into a string. Currently From 6693f56f75e8d67524e1b2cb8cbf6455d2fcf17d Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Thu, 5 Jun 2025 13:52:54 -0700 Subject: [PATCH 4/5] clean up condition lookup --- .../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index deff2908b4902..3b4f1bf7ff67d 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -2211,19 +2211,20 @@ class LowerMatrixIntrinsics { MatrixTy A = getMatrix(OpA, Shape, Builder); MatrixTy B = getMatrix(OpB, Shape, Builder); + Value *CondV[2]; if (isa(Cond->getType())) { MatrixTy C = getMatrix(Cond, Shape, Builder); - for (unsigned I = 0; I < Shape.getNumVectors(); ++I) { - auto *Sel = Builder.CreateSelect(C.getVector(I), A.getVector(I), B.getVector(I)); - Result.addVector(Sel); - } + CondV[0] = C.getVector(0); + CondV[1] = C.getVector(1); } else { - for (unsigned I = 0; I < Shape.getNumVectors(); ++I) { - auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I)); - Result.addVector(Sel); - } + CondV[0] = Cond; + CondV[1] = Cond; } + for (unsigned I = 0, E = Shape.getNumVectors(); I != E; ++I) + Result.addVector( + Builder.CreateSelect(CondV[I], A.getVector(I), B.getVector(I))); + finalizeLowering(Inst, Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * Result.getNumVectors()), From 1a4faa4c9e43aa47319ea9d917b65a72f2bb67ce Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Tue, 10 Jun 2025 11:16:14 -0700 Subject: [PATCH 5/5] review feedback: don't take shape info from select conditions --- .../Scalar/LowerMatrixIntrinsics.cpp | 30 +++++-------------- .../LowerMatrixIntrinsics/select.ll | 6 ++-- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 3b4f1bf7ff67d..3a5b0f8fdb415 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -270,19 +270,11 @@ computeShapeInfoForInst(Instruction *I, return OpShape->second; } - if (auto *Select = dyn_cast(I)) { - Type *CondTy = Select->getCondition()->getType(); - for (Use &Op : CondTy->isVectorTy() ? Select->operands() - : drop_begin(Select->operands())) { - auto OpShape = ShapeMap.find(Op); - if (OpShape != ShapeMap.end()) - return OpShape->second; - } - } - - if (isUniformShape(I)) { + if (isUniformShape(I) || isa(I)) { + auto Ops = I->operands(); + auto ShapedOps = isa(I) ? drop_begin(Ops) : Ops; // Find the first operand that has a known shape and use that. - for (auto &Op : I->operands()) { + for (auto &Op : ShapedOps) { auto OpShape = ShapeMap.find(Op.get()); if (OpShape != ShapeMap.end()) return OpShape->second; @@ -719,18 +711,12 @@ class LowerMatrixIntrinsics { } else if (isa(V)) { // Nothing to do. We forward-propagated to this so we would just // backward propagate to an instruction with an already known shape. - } else if (auto *Select = dyn_cast(V)) { - ShapeInfo Shape = ShapeMap[V]; - Type *CondTy = Select->getCondition()->getType(); - for (Use &Op : CondTy->isVectorTy() ? Select->operands() - : drop_begin(Select->operands())) { - if (setShapeInfo(Op, Shape)) - pushInstruction(Select, WorkList); - } - } else if (isUniformShape(V)) { + } else if (isUniformShape(V) || isa(V)) { + auto Ops = cast(V)->operands(); + auto ShapedOps = isa(V) ? drop_begin(Ops) : Ops; // Propagate to all operands. ShapeInfo Shape = ShapeMap[V]; - for (Use &U : cast(V)->operands()) { + for (Use &U : ShapedOps) { if (setShapeInfo(U.get(), Shape)) pushInstruction(U.get(), WorkList); } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll index 56dca7bb985d3..70b0dfdb3e7e8 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll @@ -72,12 +72,12 @@ define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) { ; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2 ; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8 -; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1 -; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2 -; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1 +; CHECK-NEXT: [[CONDV:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1 ; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4 ; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2 ; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4 +; CHECK-NEXT: [[COL_LOAD2:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> +; CHECK-NEXT: [[COL_LOAD4:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> ; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]] ; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]] ; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16