@@ -2869,6 +2869,79 @@ struct LICMElementwise
28692869 }
28702870};
28712871
2872+ static bool isIotaRange(ArrayRef<int64_t> dims) {
2873+ return llvm::all_of(llvm::enumerate(dims), [](const auto &it) {
2874+ return static_cast<int64_t>(it.index()) == it.value();
2875+ });
2876+ }
2877+
/// Matches when either of the submatchers match.
template <typename MatcherA, typename MatcherB> struct m_AnyOf {
  m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {}

  /// Returns true if either submatcher matches `op`; short-circuits, so
  /// matcherB is only consulted when matcherA fails.
  bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); }

  MatcherA matcherA;
  MatcherB matcherB;
};

/// Deduction guide so call sites can write m_AnyOf(a, b) without spelling out
/// the matcher template arguments.
template <typename MatcherA, typename MatcherB>
m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>;
2890+
2891+ /// Binary constant folder that used a generic folder function to handle both
2892+ /// ints and floats.
2893+ template <typename Fn>
2894+ static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs,
2895+ Fn &&folder) {
2896+ Attribute operands[2] = {lhs, rhs};
2897+ Type elemTy = getElementTypeOrSelf(lhs);
2898+
2899+ Attribute res;
2900+ if (isa<IntegerType>(elemTy))
2901+ res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(operands,
2902+ folder);
2903+ if (isa<FloatType>(elemTy))
2904+ res = constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void>(operands,
2905+ folder);
2906+ if (res)
2907+ return cast<TypedAttr>(res);
2908+
2909+ return nullptr;
2910+ }
2911+
/// Returns true when `op` is a broadcast_in_dim that merely rearranges its
/// operand — i.e. it preserves the element count (static shapes, equal
/// numElements) while permuting AND changing rank (a combined
/// transpose+reshape). Pure no-ops, splat-constant broadcasts, pure reshapes
/// (iota dims) and pure transposes (equal rank) all return false —
/// presumably because dedicated folds/patterns elsewhere handle those cases;
/// TODO confirm against the sibling patterns.
static bool isTransposeReshapeLikeBroadcast(stablehlo::BroadcastInDimOp op) {
  TypedValue<RankedTensorType> operand = op.getOperand();
  RankedTensorType operandTy = operand.getType();
  RankedTensorType type = op.getType();

  // No-op broadcast (identical type, identity dimension mapping): nothing to
  // rearrange, leave it for folding.
  auto dims = op.getBroadcastDimensions();
  bool isDimsIota = isIotaRange(dims);
  if (type == operandTy && isDimsIota) {
    return false;
  }

  // Splat-constant operands: broadcasting a splat is better served by
  // constant folding than by treating it as a data movement.
  if (SplatElementsAttr cstAttr; matchPattern(operand, m_Constant(&cstAttr))) {
    return false;
  }

  // Only element-count-preserving broadcasts can be transpose/reshape-like;
  // that requires both shapes to be static.
  if (operandTy.hasStaticShape() && type.hasStaticShape() &&
      type.getNumElements() == operandTy.getNumElements()) {
    // BroadcastInDim equivalent to reshape (identity mapping, rank change).
    if (isDimsIota) {
      return false;
    }
    // BroadcastInDim equivalent to transpose (same rank, permuted mapping).
    if (type.getRank() == operandTy.getRank()) {
      return false;
    }
    // Remaining case: rank-changing permutation (transpose + reshape in one).
    return true;
  }
  return false;
}
2944+
28722945// slice(broadcast x) -> broadcast(slice x)
28732946struct SliceBroadcast final
28742947 : CheckedOpRewritePattern<stablehlo::SliceOp, SliceBroadcast> {
@@ -2884,6 +2957,9 @@ struct SliceBroadcast final
28842957 if (!bcast)
28852958 return failure();
28862959
2960+ if (isTransposeReshapeLikeBroadcast(bcast))
2961+ return failure();
2962+
28872963 SmallVector<int64_t> nbcast_idx;
28882964
28892965 auto preShape = cast<RankedTensorType>(bcast.getOperand().getType());
@@ -3138,6 +3214,52 @@ struct TransposeSliceBase final
31383214 }
31393215};
31403216
// broadcast_in_dim(slice x) -> slice(broadcast_in_dim x)
// broadcast_in_dim(dynamic_slice x) -> dynamic_slice(broadcast_in_dim x)
// Only fires when the broadcast is transpose/reshape-like (element-count
// preserving; see isTransposeReshapeLikeBroadcast), so pushing it above the
// slice cannot grow the amount of data the broadcast touches.
template <typename OpTy>
struct TransposeLikeBroadcastSliceBase final
    : CheckedOpRewritePattern<stablehlo::BroadcastInDimOp,
                              TransposeLikeBroadcastSliceBase<OpTy>> {
  using CheckedOpRewritePattern<
      stablehlo::BroadcastInDimOp,
      TransposeLikeBroadcastSliceBase<OpTy>>::CheckedOpRewritePattern;

  LogicalResult matchAndRewriteImpl(stablehlo::BroadcastInDimOp op,
                                    PatternRewriter &rewriter) const {
    if (!isTransposeReshapeLikeBroadcast(op))
      return failure();
    auto sliceOp = op.getOperand().template getDefiningOp<OpTy>();
    if (!sliceOp) {
      return failure();
    }

    // If we can fuse the transpose-like broadcast into all of its users then
    // we shouldn't push it up (or at least give higher priority to other
    // passes before trying to move it up).
    if (llvm::all_of(op->getUsers(),
                     [&](auto user) { return isFusible(op, user); })) {
      return failure();
    }

    // Remember whether the slice will become dead once we bypass it.
    bool singleUser = sliceOp->getResult(0).hasOneUse();

    // Compute the pre-slice broadcast result shape: start from the broadcast
    // result and widen each mapped dimension back to the un-sliced operand
    // extent.
    SmallVector<int64_t> shape = llvm::to_vector(op.getType().getShape());
    for (auto &&[i, v] : llvm::enumerate(op.getBroadcastDimensions())) {
      shape[v] = sliceOp.getOperand().getType().getShape()[i];
    }
    auto newTranspose = stablehlo::BroadcastInDimOp::create(
        rewriter, op.getLoc(),
        RankedTensorType::get(shape, op.getType().getElementType()),
        sliceOp.getOperand(), op.getBroadcastDimensions());
    // transposeLikeSliceHelper rebuilds the slice (static or dynamic) on top
    // of the new broadcast with appropriately remapped indices.
    auto newSlice = transposeLikeSliceHelper(newTranspose, rewriter, sliceOp);
    rewriter.replaceOp(op, newSlice);
    if (singleUser) {
      rewriter.eraseOp(sliceOp);
    }
    return success();
  }
};
3262+
31413263struct TransposeAllUsersSlice final
31423264 : public CheckedOpRewritePattern<stablehlo::TransposeOp,
31433265 TransposeAllUsersSlice> {
@@ -6553,6 +6675,46 @@ struct BroadcastReshape final
65536675 if (!type)
65546676 return failure();
65556677
6678+ if (reshape.getType().hasStaticShape() && type.hasStaticShape() &&
6679+ type.getShape().size() >=
6680+ reshape.getOperand().getType().getShape().size()) {
6681+
6682+ auto deletionDims = findReshapeInsertionDims(
6683+ reshape.getType(), reshape.getOperand().getType());
6684+ if (!deletionDims.empty()) {
6685+ SmallVector<int64_t> unusedSingletons;
6686+
6687+ for (auto &&[i, s] : llvm::enumerate(type.getShape())) {
6688+ if (llvm::is_contained(op.getBroadcastDimensions(), i))
6689+ continue;
6690+ unusedSingletons.push_back(i);
6691+ }
6692+
6693+ SmallVector<int64_t> bcast(
6694+ reshape.getOperand().getType().getShape().size(), 0);
6695+
6696+ size_t bcast_idx = 0;
6697+ size_t unused_idx = 0;
6698+
6699+ for (size_t i = 0; i < reshape.getOperand().getType().getShape().size();
6700+ i++) {
6701+ if (llvm::is_contained(deletionDims, i)) {
6702+ bcast[i] = unusedSingletons[unused_idx];
6703+ unused_idx++;
6704+
6705+ } else {
6706+ bcast[i] = op.getBroadcastDimensions()[bcast_idx];
6707+ bcast_idx++;
6708+ }
6709+ }
6710+
6711+ rewriter.replaceOpWithNewOp<stablehlo::BroadcastInDimOp>(
6712+ op, op.getType(), reshape.getOperand(), bcast);
6713+
6714+ return success();
6715+ }
6716+ }
6717+
65566718 SmallVector<int64_t> dims;
65576719
65586720 size_t pre_reshape_idx = 0;
@@ -11061,6 +11223,64 @@ struct SliceReshapeElementwise final
1106111223 }
1106211224};
1106311225
// broadcast_in_dim(elementwise(x...)) -> elementwise(broadcast_in_dim(x)...)
// Only fires when the broadcast is transpose/reshape-like, so each operand
// broadcast has the same element count as the original operand.
struct TransposeLikeBroadcastElementwise final
    : CheckedOpRewritePattern<stablehlo::BroadcastInDimOp,
                              TransposeLikeBroadcastElementwise> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  // When true, only rewrite if the elementwise op has a single user (the
  // broadcast), so the original elementwise op would die anyway.
  bool onlySingleUser;

  TransposeLikeBroadcastElementwise(bool onlySingleUser, MLIRContext *context,
                                    PatternBenefit benefit = 1,
                                    ArrayRef<StringRef> generatedNames = {})
      : CheckedOpRewritePattern(context, benefit, generatedNames),
        onlySingleUser(onlySingleUser) {}

  LogicalResult matchAndRewriteImpl(stablehlo::BroadcastInDimOp op,
                                    PatternRewriter &rewriter) const {
    if (!isTransposeReshapeLikeBroadcast(op))
      return failure();
    auto elem = op.getOperand().getDefiningOp();
    if (!elem)
      return failure();

    if (!stablehlo::hasTraitElementwise(elem))
      return failure();

    bool singleUser = llvm::hasSingleElement(elem->getUsers());
    if (onlySingleUser && !singleUser)
      return failure();

    // Broadcast each operand of the elementwise op, inserting right where
    // that operand is defined so dominance is preserved for all its users.
    SmallVector<Value> ops;
    for (auto v : elem->getOperands()) {
      if (auto rop = v.getDefiningOp()) {
        rewriter.setInsertionPointAfter(rop);
      } else if (auto ba = dyn_cast<BlockArgument>(v)) {
        rewriter.setInsertionPointToStart(ba.getOwner());
      }
      // New type: broadcast's result shape with the operand's element type
      // (elementwise ops may mix element types, e.g. compare).
      auto nt = RankedTensorType::get(
          op.getType().getShape(),
          cast<RankedTensorType>(v.getType()).getElementType());
      ops.push_back(stablehlo::BroadcastInDimOp::create(
          rewriter, op.getLoc(), nt, v, op.getBroadcastDimensions()));
    }
    if (singleUser) {
      // Sole user: retype the existing elementwise op in place and drop the
      // broadcast. NOTE(review): assumes a single result (true for
      // elementwise-trait stablehlo ops) — confirm for custom ops.
      rewriter.modifyOpInPlace(elem, [&]() {
        elem->setOperands(ops);
        elem->getResult(0).setType(op.getType());
      });
      rewriter.replaceOp(op, elem);
    } else {
      // Other users still need the original result shape, so clone a new
      // elementwise op with the broadcasted type and keep the original.
      rewriter.setInsertionPointAfter(elem);
      auto newOp = rewriter.create(
          elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
          TypeRange(op.getType()), elem->getAttrs(), {}, {});
      rewriter.replaceOp(op, newOp);
    }
    return success();
  }
};
11283+
1106411284struct TransposeElementwise final
1106511285 : CheckedOpRewritePattern<stablehlo::TransposeOp, TransposeElementwise> {
1106611286 using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -12519,47 +12739,6 @@ struct DivideDivideSimplify
1251912739 }
1252012740};
1252112741
12522- //////////////// Imported from stablehlo
12523- static bool isIotaRange(ArrayRef<int64_t> dims) {
12524- return llvm::all_of(llvm::enumerate(dims), [](const auto &it) {
12525- return static_cast<int64_t>(it.index()) == it.value();
12526- });
12527- }
12528-
12529- /// Matches when either of the submatchers match.
12530- template <typename MatcherA, typename MatcherB> struct m_AnyOf {
12531- m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {}
12532-
12533- bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); }
12534-
12535- MatcherA matcherA;
12536- MatcherB matcherB;
12537- };
12538-
12539- template <typename MatcherA, typename MatcherB>
12540- m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>;
12541-
12542- /// Binary constant folder that used a generic folder function to handle both
12543- /// ints and floats.
12544- template <typename Fn>
12545- static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs,
12546- Fn &&folder) {
12547- Attribute operands[2] = {lhs, rhs};
12548- Type elemTy = getElementTypeOrSelf(lhs);
12549-
12550- Attribute res;
12551- if (isa<IntegerType>(elemTy))
12552- res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(operands,
12553- folder);
12554- if (isa<FloatType>(elemTy))
12555- res = constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void>(operands,
12556- folder);
12557- if (res)
12558- return cast<TypedAttr>(res);
12559-
12560- return nullptr;
12561- }
12562-
1256312742static stablehlo::ComparisonDirection
1256412743invertDirection(stablehlo::ComparisonDirection direction) {
1256512744 using stablehlo::ComparisonDirection;
@@ -31843,6 +32022,13 @@ void mlir::transform::addTransposeElementwise(RewritePatternSet &patterns,
3184332022 patterns.insert<TransposeElementwise>(onlySingleUser, &context, benefit);
3184432023}
3184532024
32025+ void mlir::transform::addTransposeLikeBroadcastElementwise(
32026+ RewritePatternSet &patterns, bool onlySingleUser, MLIRContext &context,
32027+ PatternBenefit benefit) {
32028+ patterns.insert<TransposeLikeBroadcastElementwise>(onlySingleUser, &context,
32029+ benefit);
32030+ }
32031+
3184632032void mlir::transform::addReshapeElementwise(RewritePatternSet &patterns,
3184732033 bool onlySingleUser,
3184832034 MLIRContext &context,
@@ -32095,9 +32281,11 @@ struct EnzymeHLOOptPass
3209532281
3209632282 if (passses & (2048 * 32)) {
3209732283 patterns.add<TransposeWhile, TransposeSliceBase<stablehlo::SliceOp>,
32284+ TransposeLikeBroadcastSliceBase<stablehlo::SliceOp>,
3209832285 TransposeConcat, TransposeDUS, TransposeIota,
3209932286 TransposeReduceWindow, TransposeReduce, TransposeSelect,
3210032287 TransposeSliceBase<stablehlo::DynamicSliceOp>,
32288+ TransposeLikeBroadcastSliceBase<stablehlo::DynamicSliceOp>,
3210132289 TransposeReverse, TransposeBatchNormTraining,
3210232290 TransposeBatchNormInference, TransposeBatchNormGrad,
3210332291 TransposeIf, TransposeFFT, TransposeReshape>(context);
0 commit comments