@@ -1056,19 +1056,20 @@ class LowerMatrixIntrinsics {
10561056
10571057 IRBuilder<> Builder (Inst);
10581058
1059- if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1060- Changed |= VisitCallInst (CInst);
1059+ const ShapeInfo &SI = ShapeMap.at (Inst);
10611060
10621061 Value *Op1;
10631062 Value *Op2;
10641063 if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1065- VisitBinaryOperator (BinOp);
1064+ VisitBinaryOperator (BinOp, SI );
10661065 else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1067- VisitUnaryOperator (UnOp);
1066+ VisitUnaryOperator (UnOp, SI);
1067+ else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1068+ VisitCallInst (CInst);
10681069 else if (match (Inst, m_Load (m_Value (Op1))))
1069- VisitLoad (cast<LoadInst>(Inst), Op1, Builder);
1070+ VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
10701071 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1071- VisitStore (cast<StoreInst>(Inst), Op1, Op2, Builder);
1072+ VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
10721073 else
10731074 continue ;
10741075 Changed = true ;
@@ -1109,10 +1110,10 @@ class LowerMatrixIntrinsics {
11091110 return Changed;
11101111 }
11111112
1112- // / Replace intrinsic calls
1113- bool VisitCallInst (CallInst *Inst) {
1114- if (! Inst->getCalledFunction () || !Inst-> getCalledFunction ()-> isIntrinsic ())
1115- return false ;
1113+ // / Replace intrinsic calls.
1114+ void VisitCallInst (CallInst *Inst) {
1115+ assert ( Inst->getCalledFunction () &&
1116+ Inst-> getCalledFunction ()-> isIntrinsic ()) ;
11161117
11171118 switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
11181119 case Intrinsic::matrix_multiply:
@@ -1128,9 +1129,9 @@ class LowerMatrixIntrinsics {
11281129 LowerColumnMajorStore (Inst);
11291130 break ;
11301131 default :
1131- return false ;
1132+ llvm_unreachable (
1133+ " only intrinsics supporting shape info should be seen here" );
11321134 }
1133- return true ;
11341135 }
11351136
11361137 // / Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -2107,48 +2108,36 @@ class LowerMatrixIntrinsics {
21072108 Builder);
21082109 }
21092110
2110- // / Lower load instructions, if shape information is available.
2111- void VisitLoad (LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2112- auto I = ShapeMap.find (Inst);
2113- assert (I != ShapeMap.end () &&
2114- " must only visit instructions with shape info" );
2115- LowerLoad (Inst, Ptr, Inst->getAlign (),
2116- Builder.getInt64 (I->second .getStride ()), Inst->isVolatile (),
2117- I->second );
2111+ // / Lower load instructions.
2112+ void VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2113+ IRBuilder<> &Builder) {
2114+ LowerLoad (Inst, Ptr, Inst->getAlign (), Builder.getInt64 (SI.getStride ()),
2115+ Inst->isVolatile (), SI);
21182116 }
21192117
2120- void VisitStore (StoreInst *Inst, Value *StoredVal, Value *Ptr,
2121- IRBuilder<> &Builder) {
2122- auto I = ShapeMap.find (Inst);
2123- assert (I != ShapeMap.end () &&
2124- " must only visit instructions with shape info" );
2118+ void VisitStore (StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2119+ Value *Ptr, IRBuilder<> &Builder) {
21252120 LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2126- Builder.getInt64 (I->second .getStride ()), Inst->isVolatile (),
2127- I->second );
2121+ Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
21282122 }
21292123
2130- // / Lower binary operators, if shape information is available.
2131- void VisitBinaryOperator (BinaryOperator *Inst) {
2132- auto I = ShapeMap.find (Inst);
2133- assert (I != ShapeMap.end () &&
2134- " must only visit instructions with shape info" );
2135-
2124+ // / Lower binary operators.
2125+ void VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
21362126 Value *Lhs = Inst->getOperand (0 );
21372127 Value *Rhs = Inst->getOperand (1 );
21382128
21392129 IRBuilder<> Builder (Inst);
2140- ShapeInfo &Shape = I->second ;
21412130
21422131 MatrixTy Result;
2143- MatrixTy A = getMatrix (Lhs, Shape , Builder);
2144- MatrixTy B = getMatrix (Rhs, Shape , Builder);
2132+ MatrixTy A = getMatrix (Lhs, SI , Builder);
2133+ MatrixTy B = getMatrix (Rhs, SI , Builder);
21452134 assert (A.isColumnMajor () == B.isColumnMajor () &&
21462135 Result.isColumnMajor () == A.isColumnMajor () &&
21472136 " operands must agree on matrix layout" );
21482137
21492138 Builder.setFastMathFlags (getFastMathFlags (Inst));
21502139
2151- for (unsigned I = 0 ; I < Shape .getNumVectors (); ++I)
2140+ for (unsigned I = 0 ; I < SI .getNumVectors (); ++I)
21522141 Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
21532142 B.getVector (I)));
21542143
@@ -2158,19 +2147,14 @@ class LowerMatrixIntrinsics {
21582147 Builder);
21592148 }
21602149
2161- // / Lower unary operators, if shape information is available.
2162- void VisitUnaryOperator (UnaryOperator *Inst) {
2163- auto I = ShapeMap.find (Inst);
2164- assert (I != ShapeMap.end () &&
2165- " must only visit instructions with shape info" );
2166-
2150+ // / Lower unary operators.
2151+ void VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
21672152 Value *Op = Inst->getOperand (0 );
21682153
21692154 IRBuilder<> Builder (Inst);
2170- ShapeInfo &Shape = I->second ;
21712155
21722156 MatrixTy Result;
2173- MatrixTy M = getMatrix (Op, Shape , Builder);
2157+ MatrixTy M = getMatrix (Op, SI , Builder);
21742158
21752159 Builder.setFastMathFlags (getFastMathFlags (Inst));
21762160
@@ -2184,7 +2168,7 @@ class LowerMatrixIntrinsics {
21842168 }
21852169 };
21862170
2187- for (unsigned I = 0 ; I < Shape .getNumVectors (); ++I)
2171+ for (unsigned I = 0 ; I < SI .getNumVectors (); ++I)
21882172 Result.addVector (BuildVectorOp (M.getVector (I)));
21892173
21902174 finalizeLowering (Inst,
0 commit comments