@@ -1056,19 +1056,24 @@ class LowerMatrixIntrinsics {
10561056
10571057 IRBuilder<> Builder (Inst);
10581058
1059+ const ShapeInfo &SI = ShapeMap.at (Inst);
1060+
10591061 if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1060- Changed |= VisitCallInst (CInst);
1062+ Changed |= tryVisitCallInst (CInst);
10611063
10621064 Value *Op1;
10631065 Value *Op2;
1064- if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1065- Changed |= VisitBinaryOperator (BinOp);
1066- if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1067- Changed |= VisitUnaryOperator (UnOp);
10681066 if (match (Inst, m_Load (m_Value (Op1))))
1069- Changed |= VisitLoad (cast<LoadInst>(Inst), Op1, Builder);
1067+ VisitLoad (cast<LoadInst>(Inst), SI , Op1, Builder);
10701068 else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1071- Changed |= VisitStore (cast<StoreInst>(Inst), Op1, Op2, Builder);
1069+ VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1070+ else if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1071+ VisitBinaryOperator (BinOp, SI);
1072+ else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1073+ VisitUnaryOperator (UnOp, SI);
1074+ else
1075+ continue ;
1076+ Changed = true ;
10721077 }
10731078
10741079 if (ORE) {
@@ -1107,7 +1112,7 @@ class LowerMatrixIntrinsics {
11071112 }
11081113
11091114 /// Replace intrinsic calls
1110- bool VisitCallInst (CallInst *Inst) {
1115+ bool tryVisitCallInst (CallInst *Inst) {
11111116 if (!Inst->getCalledFunction () || !Inst->getCalledFunction ()->isIntrinsic ())
11121117 return false ;
11131118
@@ -2105,72 +2110,53 @@ class LowerMatrixIntrinsics {
21052110 }
21062111
21072112 /// Lower load instructions, if shape information is available.
2108- bool VisitLoad (LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2109- auto I = ShapeMap.find (Inst);
2110- assert (I != ShapeMap.end () &&
2111- " must only visit instructions with shape info" );
2113+ void VisitLoad (LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) {
21122114 LowerLoad (Inst, Ptr, Inst->getAlign (),
2113- Builder.getInt64 (I->second .getStride ()), Inst->isVolatile (),
2114- I->second );
2115- return true ;
2115+ Builder.getInt64 (SI.getStride ()), Inst->isVolatile (),
2116+ SI);
21162117 }
21172118
2118- bool VisitStore (StoreInst *Inst, Value *StoredVal, Value *Ptr,
2119+ void VisitStore (StoreInst *Inst, const ShapeInfo &SI , Value *StoredVal, Value *Ptr,
21192120 IRBuilder<> &Builder) {
2120- auto I = ShapeMap.find (StoredVal);
2121- assert (I != ShapeMap.end () &&
2122- " must only visit instructions with shape info" );
21232121 LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2124- Builder.getInt64 (I->second .getStride ()), Inst->isVolatile (),
2125- I->second );
2126- return true ;
2122+ Builder.getInt64 (SI.getStride ()), Inst->isVolatile (),
2123+ SI);
21272124 }
21282125
21292126 /// Lower binary operators, if shape information is available.
2130- bool VisitBinaryOperator (BinaryOperator *Inst) {
2131- auto I = ShapeMap.find (Inst);
2132- assert (I != ShapeMap.end () &&
2133- " must only visit instructions with shape info" );
2134-
2127+ void VisitBinaryOperator (BinaryOperator *Inst, const ShapeInfo &SI) {
21352128 Value *Lhs = Inst->getOperand (0 );
21362129 Value *Rhs = Inst->getOperand (1 );
21372130
21382131 IRBuilder<> Builder (Inst);
2139- ShapeInfo &Shape = I->second ;
21402132
21412133 MatrixTy Result;
2142- MatrixTy A = getMatrix (Lhs, Shape , Builder);
2143- MatrixTy B = getMatrix (Rhs, Shape , Builder);
2134+ MatrixTy A = getMatrix (Lhs, SI , Builder);
2135+ MatrixTy B = getMatrix (Rhs, SI , Builder);
21442136 assert (A.isColumnMajor () == B.isColumnMajor () &&
21452137 Result.isColumnMajor () == A.isColumnMajor () &&
21462138 " operands must agree on matrix layout" );
21472139
21482140 Builder.setFastMathFlags (getFastMathFlags (Inst));
21492141
2150- for (unsigned I = 0 ; I < Shape .getNumVectors (); ++I)
2142+ for (unsigned I = 0 ; I < SI .getNumVectors (); ++I)
21512143 Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
21522144 B.getVector (I)));
21532145
21542146 finalizeLowering (Inst,
21552147 Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
21562148 Result.getNumVectors ()),
21572149 Builder);
2158- return true ;
21592150 }
21602151
21612152 /// Lower unary operators, if shape information is available.
2162- bool VisitUnaryOperator (UnaryOperator *Inst) {
2163- auto I = ShapeMap.find (Inst);
2164- assert (I != ShapeMap.end () &&
2165- " must only visit instructions with shape info" );
2166-
2153+ void VisitUnaryOperator (UnaryOperator *Inst, const ShapeInfo &SI) {
21672154 Value *Op = Inst->getOperand (0 );
21682155
21692156 IRBuilder<> Builder (Inst);
2170- ShapeInfo &Shape = I->second ;
21712157
21722158 MatrixTy Result;
2173- MatrixTy M = getMatrix (Op, Shape , Builder);
2159+ MatrixTy M = getMatrix (Op, SI , Builder);
21742160
21752161 Builder.setFastMathFlags (getFastMathFlags (Inst));
21762162
@@ -2184,14 +2170,13 @@ class LowerMatrixIntrinsics {
21842170 }
21852171 };
21862172
2187- for (unsigned I = 0 ; I < Shape .getNumVectors (); ++I)
2173+ for (unsigned I = 0 ; I < SI .getNumVectors (); ++I)
21882174 Result.addVector (BuildVectorOp (M.getVector (I)));
21892175
21902176 finalizeLowering (Inst,
21912177 Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
21922178 Result.getNumVectors ()),
21932179 Builder);
2194- return true ;
21952180 }
21962181
21972182 /// Helper to linearize a matrix expression tree into a string. Currently
0 commit comments