@@ -792,7 +792,8 @@ class LowerMatrixIntrinsics {
792792 // / This creates and erases instructions as needed, and returns the newly
793793 // / created instruction while updating the iterator to avoid invalidation. If
794794 // / this returns nullptr, no new instruction was created.
795- Instruction *sinkTranspose (Instruction &I, BasicBlock::reverse_iterator &II) {
795+ Instruction *sinkTranspose (Instruction &I, BasicBlock::reverse_iterator &II,
796+ bool &Changed) {
796797 BasicBlock &BB = *I.getParent ();
797798 IRBuilder<> IB (&I);
798799 MatrixBuilder Builder (IB);
@@ -809,13 +810,15 @@ class LowerMatrixIntrinsics {
809810 updateShapeAndReplaceAllUsesWith (I, TATA);
810811 eraseFromParentAndMove (&I, II, BB);
811812 eraseFromParentAndMove (TA, II, BB);
813+ Changed = true ;
812814 return nullptr ;
813815 }
814816
815817 // k^T -> k
816818 if (isSplat (TA)) {
817819 updateShapeAndReplaceAllUsesWith (I, TA);
818820 eraseFromParentAndMove (&I, II, BB);
821+ Changed = true ;
819822 return nullptr ;
820823 }
821824
@@ -834,6 +837,7 @@ class LowerMatrixIntrinsics {
834837 updateShapeAndReplaceAllUsesWith (I, NewInst);
835838 eraseFromParentAndMove (&I, II, BB);
836839 eraseFromParentAndMove (TA, II, BB);
840+ Changed = true ;
837841 return NewInst;
838842 }
839843
@@ -859,6 +863,7 @@ class LowerMatrixIntrinsics {
859863 updateShapeAndReplaceAllUsesWith (I, NewInst);
860864 eraseFromParentAndMove (&I, II, BB);
861865 eraseFromParentAndMove (TA, II, BB);
866+ Changed = true ;
862867 return NewInst;
863868 }
864869
@@ -880,13 +885,14 @@ class LowerMatrixIntrinsics {
880885 updateShapeAndReplaceAllUsesWith (I, NewInst);
881886 eraseFromParentAndMove (&I, II, BB);
882887 eraseFromParentAndMove (TA, II, BB);
888+ Changed = true ;
883889 return NewInst;
884890 }
885891
886892 return nullptr ;
887893 }
888894
889- void liftTranspose (Instruction &I) {
895+ bool liftTranspose (Instruction &I) {
890896 // Erase dead Instructions after lifting transposes from binops.
891897 auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
892898 if (T.use_empty ())
@@ -914,6 +920,7 @@ class LowerMatrixIntrinsics {
914920 R->getZExtValue ());
915921 updateShapeAndReplaceAllUsesWith (I, NewInst);
916922 CleanupBinOp (I, A, B);
923+ return true ;
917924 }
918925 // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
919926 // the shape of the second transpose is different, there's a shape conflict
@@ -940,19 +947,22 @@ class LowerMatrixIntrinsics {
940947 ShapeMap[AddI] &&
941948 " Shape of updated addition doesn't match cached shape." );
942949 }
950+ return true ;
943951 }
952+ return false ;
944953 }
945954
946955 // / Try moving transposes in order to fold them away or into multiplies.
947- void optimizeTransposes () {
956+ bool optimizeTransposes () {
957+ bool Changed = false ;
948958 // First sink all transposes inside matmuls and adds, hoping that we end up
949959 // with NN, NT or TN variants.
950960 for (BasicBlock &BB : reverse (Func)) {
951961 for (auto II = BB.rbegin (); II != BB.rend ();) {
952962 Instruction &I = *II;
953963 // We may remove II. By default continue on the next/prev instruction.
954964 ++II;
955- if (Instruction *NewInst = sinkTranspose (I, II))
965+ if (Instruction *NewInst = sinkTranspose (I, II, Changed ))
956966 II = std::next (BasicBlock::reverse_iterator (NewInst));
957967 }
958968 }
@@ -961,9 +971,10 @@ class LowerMatrixIntrinsics {
961971 // to fold into consuming multiply or add.
962972 for (BasicBlock &BB : Func) {
963973 for (Instruction &I : llvm::make_early_inc_range (BB)) {
964- liftTranspose (I);
974+ Changed |= liftTranspose (I);
965975 }
966976 }
977+ return Changed;
967978 }
968979
969980 bool Visit () {
@@ -1006,15 +1017,15 @@ class LowerMatrixIntrinsics {
10061017 WorkList = propagateShapeBackward (WorkList);
10071018 }
10081019
1020+ bool Changed = false ;
10091021 if (!isMinimal ()) {
1010- optimizeTransposes ();
1022+ Changed |= optimizeTransposes ();
10111023 if (PrintAfterTransposeOpt) {
10121024 dbgs () << " Dump after matrix transpose optimization:\n " ;
10131025 Func.print (dbgs ());
10141026 }
10151027 }
10161028
1017- bool Changed = false ;
10181029 SmallVector<CallInst *, 16 > MaybeFusableInsts;
10191030 SmallVector<Instruction *, 16 > MatrixInsts;
10201031 SmallVector<IntrinsicInst *, 16 > LifetimeEnds;
@@ -1043,7 +1054,7 @@ class LowerMatrixIntrinsics {
10431054 if (!FusedInsts.contains (CI))
10441055 LowerMatrixMultiplyFused (CI, FusedInsts, LifetimeEnds);
10451056
1046- Changed = !FusedInsts.empty ();
1057+ Changed | = !FusedInsts.empty ();
10471058
10481059 // Fourth, lower remaining instructions with shape information.
10491060 for (Instruction *Inst : MatrixInsts) {
0 commit comments