@@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
965965 std::function<bool (BitCastOp)> controlFn;
966966};
967967
968+ static bool haveSameShapeAndScaling (Type t, Type u) {
969+ auto tVec = dyn_cast<VectorType>(t);
970+ auto uVec = dyn_cast<VectorType>(u);
971+ if (!tVec) {
972+ return !uVec;
973+ }
974+ if (!uVec) {
975+ return false ;
976+ }
977+ return tVec.getShape () == uVec.getShape () &&
978+ tVec.getScalableDims () == uVec.getScalableDims ();
979+ }
980+
981+ // / If `type` is shaped, clone it with `newElementType`. Otherwise,
982+ // / return `newElementType`.
983+ static Type cloneOrReplace (Type type, Type newElementType) {
984+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
985+ return shapedType.clone (newElementType);
986+ }
987+ return newElementType;
988+ }
989+
968990// / Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
969991// /
970992// / Example:
@@ -988,23 +1010,22 @@ struct ReorderElementwiseOpsOnBroadcast final
9881010 PatternRewriter &rewriter) const override {
9891011 if (op->getNumResults () != 1 )
9901012 return failure ();
991- if (!llvm::isa<ShapedType>(op->getResults ()[0 ].getType ()))
1013+ auto resultType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
1014+ if (!resultType)
9921015 return failure ();
9931016 if (!OpTrait::hasElementwiseMappableTraits (op))
9941017 return rewriter.notifyMatchFailure (
9951018 op, " Op doesn't have ElementwiseMappableTraits" );
9961019 if (op->getNumOperands () == 0 )
9971020 return failure ();
998- if (op->getResults ()[0 ].getType () != op->getOperand (0 ).getType ())
999- return rewriter.notifyMatchFailure (op,
1000- " result and operand type mismatch" );
10011021 if (isa<vector::FMAOp>(op)) {
10021022 return rewriter.notifyMatchFailure (
10031023 op,
10041024 " Op only accepts vector types - not supported as broadcast source "
10051025 " might be a scalar" );
10061026 }
10071027
1028+ Type resultElemType = resultType.getElementType ();
10081029 // Get the type of the first non-constant operand
10091030 Operation *firstBroadcastOrSplat = nullptr ;
10101031 for (Value operand : op->getOperands ()) {
@@ -1020,24 +1041,23 @@ struct ReorderElementwiseOpsOnBroadcast final
10201041 }
10211042 if (!firstBroadcastOrSplat)
10221043 return failure ();
1023- Type firstBroadcastOrSplatType =
1024- firstBroadcastOrSplat->getOperand (0 ).getType ();
1044+ Type unbroadcastResultType = cloneOrReplace (
1045+ firstBroadcastOrSplat->getOperand (0 ).getType (), resultElemType) ;
10251046
1026- // Make sure that all operands are broadcast from identical types:
1047+ // Make sure that all operands are broadcast from identically-shaped types:
10271048 // * scalar (`vector.broadcast` + `vector.splat`), or
10281049 // * vector (`vector.broadcast`).
10291050 // Otherwise the re-ordering wouldn't be safe.
1030- if (!llvm::all_of (
1031- op->getOperands (), [&firstBroadcastOrSplatType](Value val) {
1032- if (auto bcastOp = val.getDefiningOp <vector::BroadcastOp>())
1033- return (bcastOp.getOperand ().getType () ==
1034- firstBroadcastOrSplatType);
1035- if (auto splatOp = val.getDefiningOp <vector::SplatOp>())
1036- return (splatOp.getOperand ().getType () ==
1037- firstBroadcastOrSplatType);
1038- SplatElementsAttr splatConst;
1039- return matchPattern (val, m_Constant (&splatConst));
1040- })) {
1051+ if (!llvm::all_of (op->getOperands (), [&unbroadcastResultType](Value val) {
1052+ if (auto bcastOp = val.getDefiningOp <vector::BroadcastOp>())
1053+ return haveSameShapeAndScaling (bcastOp.getOperand ().getType (),
1054+ unbroadcastResultType);
1055+ if (auto splatOp = val.getDefiningOp <vector::SplatOp>())
1056+ return haveSameShapeAndScaling (splatOp.getOperand ().getType (),
1057+ unbroadcastResultType);
1058+ SplatElementsAttr splatConst;
1059+ return matchPattern (val, m_Constant (&splatConst));
1060+ })) {
10411061 return failure ();
10421062 }
10431063
@@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final
10481068 SplatElementsAttr splatConst;
10491069 if (matchPattern (operand, m_Constant (&splatConst))) {
10501070 Attribute newConst;
1051- if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
1052- newConst = splatConst.resizeSplat (shapedTy);
1071+ Type elementType = getElementTypeOrSelf (operand.getType ());
1072+ Type newType = cloneOrReplace (unbroadcastResultType, elementType);
1073+ if (auto shapedTy = dyn_cast<ShapedType>(unbroadcastResultType)) {
1074+ newConst = splatConst.resizeSplat (cast<ShapedType>(newType));
10531075 } else {
10541076 newConst = splatConst.getSplatValue <Attribute>();
10551077 }
10561078 Operation *newConstOp =
10571079 operand.getDefiningOp ()->getDialect ()->materializeConstant (
1058- rewriter, newConst, firstBroadcastOrSplatType,
1059- operand.getLoc ());
1080+ rewriter, newConst, newType, operand.getLoc ());
10601081 srcValues.push_back (newConstOp->getResult (0 ));
10611082 } else {
10621083 srcValues.push_back (operand.getDefiningOp ()->getOperand (0 ));
@@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final
10661087 // Create the "elementwise" Op
10671088 Operation *elementwiseOp =
10681089 rewriter.create (op->getLoc (), op->getName ().getIdentifier (), srcValues,
1069- firstBroadcastOrSplatType , op->getAttrs ());
1090+ unbroadcastResultType , op->getAttrs ());
10701091
10711092 // Replace the original Op with the elementwise Op
1072- auto vectorType = op->getResultTypes ()[0 ];
10731093 rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
1074- op, vectorType , elementwiseOp->getResults ());
1094+ op, resultType , elementwiseOp->getResults ());
10751095
10761096 return success ();
10771097 }
0 commit comments