Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b1b79aa
[Matrix] Propagate shape information through PHI instructions
jroelofs May 27, 2025
71b99d3
move formerly unsupported test to new home
jroelofs May 27, 2025
905c1e9
clang-format
jroelofs May 27, 2025
9ee44f0
add test for ConstantDataVector lowering
jroelofs May 27, 2025
169960d
move report_fatal_error outside of NDEBUG block
jroelofs May 27, 2025
b64a134
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs May 28, 2025
18951cd
fix bad merge
jroelofs May 31, 2025
f8aea05
use col major load intrinsics
jroelofs Jun 2, 2025
e56b225
add tests for phi's consuming phi's, and phi's with more than two inputs
jroelofs Jun 2, 2025
ffbc73f
handle phi's more like other ops. instcombine will clean up after us
jroelofs Jun 2, 2025
15fd60b
handle phi's with shape mismatch
jroelofs Jun 2, 2025
88bd8cb
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs Jun 2, 2025
655eb88
simplify getMatrix shim
jroelofs Jun 2, 2025
e262f76
test the other order of shape mismatch
jroelofs Jun 2, 2025
2c86c2f
clang-format
jroelofs Jun 9, 2025
86d3545
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 10, 2025
501414f
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 10, 2025
2e5b2d4
clang-format
jroelofs Jun 11, 2025
f048181
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 11, 2025
7511d17
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 12, 2025
2821467
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 12, 2025
4ba4e66
[Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch
jroelofs Jun 12, 2025
4cbc839
review feedback: parens for initializer
jroelofs Jun 16, 2025
6f8ec49
review feedback: rename to GetMatrix
jroelofs Jun 16, 2025
104c126
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 16, 2025
67ead37
drop code for splitting constants, add test for it
jroelofs Jun 16, 2025
c9b3992
split phi's in two phases
jroelofs Jun 16, 2025
da17d10
Merge remote-tracking branch 'origin/main' into jroelofs/lower-matrix…
jroelofs Jun 16, 2025
dd55682
clang-format
jroelofs Jun 16, 2025
08be3b4
Merge branch 'main' into jroelofs/lower-matrix-phi
jroelofs Jun 17, 2025
5a991f7
rm constant.ll
jroelofs Jun 17, 2025
e9d0e62
test that shows reshape shuffles are inserted in the correct spot
jroelofs Jun 18, 2025
a760840
florian's suggestion is a little simpler: we already know it's a phi
jroelofs Jun 18, 2025
a62ef50
also use getInsertionPointAtDef() for non-inst phi operand reshape pl…
jroelofs Jun 18, 2025
47dd2e2
inline the lambda shim, to simplify
jroelofs Jun 18, 2025
0debc08
rm unnecessary include
jroelofs Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
Expand Down Expand Up @@ -230,6 +231,7 @@ static bool isUniformShape(Value *V) {
return true;

switch (I->getOpcode()) {
case Instruction::PHI:
case Instruction::FAdd:
case Instruction::FSub:
case Instruction::FMul: // Scalar multiply.
Expand Down Expand Up @@ -360,6 +362,33 @@ class LowerMatrixIntrinsics {
addVector(PoisonValue::get(FixedVectorType::get(
EltTy, isColumnMajor() ? NumRows : NumColumns)));
}
/// Build the split representation of the constant matrix \p Constant using
/// the shape \p SI: one vector is added per SI.getNumVectors(), each holding
/// SI.getStride() elements.
///
/// Handled constants are ConstantDataVector (sliced from its raw data),
/// poison, undef and zeroinitializer. Any other ConstantData kind is a fatal
/// error in every build mode; silently continuing would leave this matrix
/// without any vectors.
MatrixTy(ConstantData *Constant, const ShapeInfo &SI)
    : IsColumnMajor(SI.IsColumnMajor) {
  Type *EltTy = cast<VectorType>(Constant->getType())->getElementType();
  // Every branch below must produce vectors of the same width. Use the stride
  // (the element count of the ConstantDataVector slices) rather than NumRows,
  // which only matches when the stride happens to equal the row count.
  const unsigned Width = SI.getStride();
  Type *VecTy = VectorType::get(EltTy, ElementCount::getFixed(Width));

  for (unsigned J = 0, D = SI.getNumVectors(); J < D; ++J) {
    if (auto *CDV = dyn_cast<ConstantDataVector>(Constant)) {
      // Carve the J-th run of Width elements out of the raw byte buffer.
      size_t EltSize = EltTy->getPrimitiveSizeInBits() / 8;
      StringRef Data = CDV->getRawDataValues().substr(J * Width * EltSize,
                                                      Width * EltSize);
      addVector(
          ConstantDataVector::getRaw(Data, Width, CDV->getElementType()));
    } else if (isa<PoisonValue>(Constant)) {
      // PoisonValue is tested before UndefValue, as in the original ordering.
      addVector(PoisonValue::get(VecTy));
    } else if (isa<UndefValue>(Constant)) {
      addVector(UndefValue::get(VecTy));
    } else if (isa<ConstantAggregateZero>(Constant)) {
      addVector(ConstantAggregateZero::get(VecTy));
    } else {
#ifndef NDEBUG
      Constant->dump();
#endif
      // Must not be compiled out with assertions: see the function comment.
      report_fatal_error("unhandled ConstantData type");
    }
  }
}

/// Return the vector stored at index \p i of this matrix.
Value *getVector(unsigned i) const {
  Value *Vec = Vectors[i];
  return Vec;
}
Value *getColumn(unsigned i) const {
Expand Down Expand Up @@ -564,6 +593,27 @@ class LowerMatrixIntrinsics {
MatrixVal = M.embedInVector(Builder);
}

// If it's a PHI, split it now. We'll take care of fixing up the operands
// later once we're in VisitPHI.
if (auto *PHI = dyn_cast<PHINode>(MatrixVal)) {
auto *EltTy = cast<VectorType>(PHI->getType())->getElementType();
MatrixTy PhiM{SI.NumRows, SI.NumColumns, EltTy};

IRBuilder<>::InsertPointGuard IPG(Builder);
Builder.SetInsertPoint(PHI);
for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI)
PhiM.setVector(VI, Builder.CreatePHI(PhiM.getVectorTy(),
PHI->getNumIncomingValues(),
PHI->getName()));

Inst2ColumnMatrix[PHI] = PhiM;
return PhiM;
}

// If it's a constant, materialize the split version of it with this shape.
if (auto *IncomingConst = dyn_cast<ConstantData>(MatrixVal))
return MatrixTy(IncomingConst, SI);

// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
for (unsigned MaskStart = 0;
Expand Down Expand Up @@ -1077,6 +1127,11 @@ class LowerMatrixIntrinsics {
Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
}

// Fifth, lower all the PHI's with shape information.
for (Instruction *Inst : MatrixInsts)
if (auto *PHI = dyn_cast<PHINode>(Inst))
Changed |= VisitPHI(PHI);

if (ORE) {
RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
RemarkGen.emitRemarks();
Expand Down Expand Up @@ -1349,7 +1404,8 @@ class LowerMatrixIntrinsics {
IRBuilder<> &Builder) {
auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
(void)inserted;
assert(inserted.second && "multiple matrix lowering mapping");
assert((inserted.second || isa<PHINode>(Inst)) &&
"multiple matrix lowering mapping");

ToRemove.push_back(Inst);
Value *Flattened = nullptr;
Expand Down Expand Up @@ -2133,6 +2189,41 @@ class LowerMatrixIntrinsics {
return true;
}

/// Lower the PHI \p Inst if shape information is available for it.
///
/// The per-vector PHIs for \p Inst are obtained through getMatrix(); this
/// function then adds, for every incoming edge of \p Inst, the matching
/// vector of the (split) incoming value to each of those PHIs and finalizes
/// the lowering.
///
/// \returns false if \p Inst has no entry in ShapeMap, true otherwise.
bool VisitPHI(PHINode *Inst) {
  auto I = ShapeMap.find(Inst);
  if (I == ShapeMap.end())
    return false;

  IRBuilder<> Builder(Inst);

  // The vectors of PhiM are PHI nodes; their incoming values are added below.
  MatrixTy PhiM = getMatrix(Inst, I->second, Builder);

  for (unsigned IncomingI = 0, IncomingE = Inst->getNumIncomingValues();
       IncomingI != IncomingE; ++IncomingI) {
    Value *IncomingV = Inst->getIncomingValue(IncomingI);
    BasicBlock *IncomingB = Inst->getIncomingBlock(IncomingI);

    // getMatrix() may insert some instructions. Always pick an explicit
    // insertion point for them:
    //  * for an instruction operand, the end of the block that defines it;
    //  * for any other operand (e.g. a function argument), the end of the
    //    incoming block, where the register allocator would have inserted the
    //    copies that materialize the PHI.
    // Without the second case the builder would keep a stale position: either
    // in front of the PHI itself, or wherever the previous operand left it.
    if (auto *IncomingInst = dyn_cast<Instruction>(IncomingV))
      Builder.SetInsertPoint(IncomingInst->getParent()->getTerminator());
    else
      Builder.SetInsertPoint(IncomingB->getTerminator());

    MatrixTy OpM = getMatrix(IncomingV, I->second, Builder);

    for (unsigned VI = 0, VE = PhiM.getNumVectors(); VI != VE; ++VI) {
      PHINode *NewPHI = cast<PHINode>(PhiM.getVector(VI));
      NewPHI->addIncoming(OpM.getVector(VI), IncomingB);
    }
  }

  // finalizeLowering() may also insert instructions in some cases. The safe
  // place for those is at the end of the initial block of PHIs.
  Builder.SetInsertPoint(*Inst->getInsertionPointAfterDef());
  finalizeLowering(Inst, PhiM, Builder);
  return true;
}

/// Lower binary operators, if shape information is available.
bool VisitBinaryOperator(BinaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
Expand Down
216 changes: 216 additions & 0 deletions llvm/test/Transforms/LowerMatrixIntrinsics/phi.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -matrix-allow-contract=false -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s

; A loop-carried <9 x double> phi fed by a load (from %entry) and by the fadd
; result %sum (back edge). The CHECK lines expect it to be split into three
; <3 x double> phis, one per vector of the 3 x 3 matrix.
define void @matrix_phi(ptr %in1, ptr %in2, i32 %count, ptr %out) {
; CHECK-LABEL: @matrix_phi(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN1:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN1]], i64 3
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN1]], i64 6
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[PHI9:%.*]] = phi <3 x double> [ [[COL_LOAD]], [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI10:%.*]] = phi <3 x double> [ [[COL_LOAD1]], [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI11:%.*]] = phi <3 x double> [ [[COL_LOAD3]], [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[IN2]], i64 3
; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <3 x double>, ptr [[VEC_GEP5]], align 8
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[IN2]], i64 6
; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <3 x double>, ptr [[VEC_GEP7]], align 16
; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI9]], [[COL_LOAD4]]
; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI10]], [[COL_LOAD6]]
; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI11]], [[COL_LOAD8]]
; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
; CHECK: exit:
; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP12:%.*]] = getelementptr double, ptr [[OUT]], i64 3
; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP12]], align 8
; CHECK-NEXT: [[VEC_GEP13:%.*]] = getelementptr double, ptr [[OUT]], i64 6
; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP13]], align 16
; CHECK-NEXT: ret void
;
entry:
%mat = load <9 x double>, ptr %in1
br label %loop

loop:
%phi = phi <9 x double> [%mat, %entry], [%sum, %loop]
%ctr = phi i32 [%count, %entry], [%dec, %loop]

%in2v = load <9 x double>, ptr %in2

; Give in2 the shape: 3 x 3. The two transposes cancel out (no transpose code
; is expected in the CHECK lines); the calls only carry the 3 x 3 dimensions.
%in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
%in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)

%sum = fadd <9 x double> %phi, %in2tt

%dec = sub i32 %ctr, 1
%cmp = icmp eq i32 %dec, 0
br i1 %cmp, label %exit, label %loop

exit:
store <9 x double> %sum, ptr %out
ret void
}

; A loop-carried <9 x double> phi whose %entry operand is the constant
; zeroinitializer. The CHECK lines expect three <3 x double> phis, each with a
; zeroinitializer operand from the entry block.
define void @matrix_phi_zeroinitializer(ptr %in1, ptr %in2, i32 %count, ptr %out) {
; CHECK-LABEL: @matrix_phi_zeroinitializer(
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ zeroinitializer, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
; CHECK: exit:
; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
; CHECK-NEXT: ret void
;
entry:
br label %loop

loop:
%phi = phi <9 x double> [zeroinitializer, %entry], [%sum, %loop]
%ctr = phi i32 [%count, %entry], [%dec, %loop]

%in2v = load <9 x double>, ptr %in2

; Give in2 the shape: 3 x 3. The two transposes cancel out (no transpose code
; is expected in the CHECK lines); the calls only carry the 3 x 3 dimensions.
%in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
%in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)

%sum = fadd <9 x double> %phi, %in2tt

%dec = sub i32 %ctr, 1
%cmp = icmp eq i32 %dec, 0
br i1 %cmp, label %exit, label %loop

exit:
store <9 x double> %sum, ptr %out
ret void
}

; A loop-carried <9 x double> phi whose %entry operand is undef. The CHECK
; lines expect three <3 x double> phis, each with an undef operand from the
; entry block.
define void @matrix_phi_undef(ptr %in1, ptr %in2, i32 %count, ptr %out) {
; CHECK-LABEL: @matrix_phi_undef(
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ undef, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ undef, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
; CHECK: exit:
; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
; CHECK-NEXT: ret void
;
entry:
br label %loop

loop:
%phi = phi <9 x double> [undef, %entry], [%sum, %loop]
%ctr = phi i32 [%count, %entry], [%dec, %loop]

%in2v = load <9 x double>, ptr %in2

; Give in2 the shape: 3 x 3. The two transposes cancel out (no transpose code
; is expected in the CHECK lines); the calls only carry the 3 x 3 dimensions.
%in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
%in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)

%sum = fadd <9 x double> %phi, %in2tt

%dec = sub i32 %ctr, 1
%cmp = icmp eq i32 %dec, 0
br i1 %cmp, label %exit, label %loop

exit:
store <9 x double> %sum, ptr %out
ret void
}

; A loop-carried <9 x double> phi whose %entry operand is poison. The CHECK
; lines expect three <3 x double> phis, each with a poison operand from the
; entry block.
define void @matrix_phi_poison(ptr %in1, ptr %in2, i32 %count, ptr %out) {
; CHECK-LABEL: @matrix_phi_poison(
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[PHI4:%.*]] = phi <3 x double> [ poison, [[ENTRY:%.*]] ], [ [[TMP0:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI5:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP1:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[PHI6:%.*]] = phi <3 x double> [ poison, [[ENTRY]] ], [ [[TMP2:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[CTR:%.*]] = phi i32 [ [[COUNT:%.*]], [[ENTRY]] ], [ [[DEC:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <3 x double>, ptr [[IN2:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN2]], i64 3
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <3 x double>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[IN2]], i64 6
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <3 x double>, ptr [[VEC_GEP2]], align 16
; CHECK-NEXT: [[TMP0]] = fadd <3 x double> [[PHI4]], [[COL_LOAD]]
; CHECK-NEXT: [[TMP1]] = fadd <3 x double> [[PHI5]], [[COL_LOAD1]]
; CHECK-NEXT: [[TMP2]] = fadd <3 x double> [[PHI6]], [[COL_LOAD3]]
; CHECK-NEXT: [[DEC]] = sub i32 [[CTR]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[DEC]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[EXIT:%.*]], label [[LOOP]]
; CHECK: exit:
; CHECK-NEXT: store <3 x double> [[TMP0]], ptr [[OUT:%.*]], align 128
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr double, ptr [[OUT]], i64 3
; CHECK-NEXT: store <3 x double> [[TMP1]], ptr [[VEC_GEP7]], align 8
; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr double, ptr [[OUT]], i64 6
; CHECK-NEXT: store <3 x double> [[TMP2]], ptr [[VEC_GEP8]], align 16
; CHECK-NEXT: ret void
;
entry:
br label %loop

loop:
%phi = phi <9 x double> [poison, %entry], [%sum, %loop]
%ctr = phi i32 [%count, %entry], [%dec, %loop]

%in2v = load <9 x double>, ptr %in2

; Give in2 the shape: 3 x 3. The two transposes cancel out (no transpose code
; is expected in the CHECK lines); the calls only carry the 3 x 3 dimensions.
%in2t = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2v, i32 3, i32 3)
%in2tt = call <9 x double> @llvm.matrix.transpose(<9 x double> %in2t, i32 3, i32 3)

%sum = fadd <9 x double> %phi, %in2tt

%dec = sub i32 %ctr, 1
%cmp = icmp eq i32 %dec, 0
br i1 %cmp, label %exit, label %loop

exit:
store <9 x double> %sum, ptr %out
ret void
}
Loading
Loading