@@ -229,14 +229,11 @@ static bool isUniformShape(Value *V) {
229229 if (!I)
230230 return true ;
231231
232+ if (I->isBinaryOp ())
233+ return true ;
234+
232235 switch (I->getOpcode ()) {
233- case Instruction::FAdd:
234- case Instruction::FSub:
235- case Instruction::FMul: // Scalar multiply.
236236 case Instruction::FNeg:
237- case Instruction::Add:
238- case Instruction::Mul:
239- case Instruction::Sub:
240237 return true ;
241238 default :
242239 return false ;
@@ -2154,28 +2151,9 @@ class LowerMatrixIntrinsics {
21542151
21552152 Builder.setFastMathFlags (getFastMathFlags (Inst));
21562153
2157- // Helper to perform binary op on vectors.
2158- auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2159- switch (Inst->getOpcode ()) {
2160- case Instruction::Add:
2161- return Builder.CreateAdd (LHS, RHS);
2162- case Instruction::Mul:
2163- return Builder.CreateMul (LHS, RHS);
2164- case Instruction::Sub:
2165- return Builder.CreateSub (LHS, RHS);
2166- case Instruction::FAdd:
2167- return Builder.CreateFAdd (LHS, RHS);
2168- case Instruction::FMul:
2169- return Builder.CreateFMul (LHS, RHS);
2170- case Instruction::FSub:
2171- return Builder.CreateFSub (LHS, RHS);
2172- default :
2173- llvm_unreachable (" Unsupported binary operator for matrix" );
2174- }
2175- };
2176-
21772154 for (unsigned I = 0 ; I < Shape.getNumVectors (); ++I)
2178- Result.addVector (BuildVectorOp (A.getVector (I), B.getVector (I)));
2155+ Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
2156+ B.getVector (I)));
21792157
21802158 finalizeLowering (Inst,
21812159 Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
0 commit comments