Skip to content

Commit 8fbd94d

Browse files
authored
Transpose-like slice and elemwise (#2033)
* Transpose-like slice and elemwise * more * fix * fmt * fix * more * more fix * fix * fix * fix
1 parent 4bcd862 commit 8fbd94d

19 files changed

+549
-217
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 229 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,6 +2869,79 @@ struct LICMElementwise
28692869
}
28702870
};
28712871

2872+
/// Returns true when `dims` is exactly the sequence 0, 1, ..., size-1, i.e.
/// every entry equals its own position. An empty `dims` also yields true.
static bool isIotaRange(ArrayRef<int64_t> dims) {
  for (size_t pos = 0, count = dims.size(); pos != count; ++pos) {
    if (dims[pos] != static_cast<int64_t>(pos))
      return false;
  }
  return true;
}
2877+
2878+
/// Matches when either of the submatchers match.
2879+
template <typename MatcherA, typename MatcherB> struct m_AnyOf {
2880+
m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {}
2881+
2882+
bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); }
2883+
2884+
MatcherA matcherA;
2885+
MatcherB matcherB;
2886+
};
2887+
2888+
template <typename MatcherA, typename MatcherB>
2889+
m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>;
2890+
2891+
/// Binary constant folder that used a generic folder function to handle both
2892+
/// ints and floats.
2893+
template <typename Fn>
2894+
static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs,
2895+
Fn &&folder) {
2896+
Attribute operands[2] = {lhs, rhs};
2897+
Type elemTy = getElementTypeOrSelf(lhs);
2898+
2899+
Attribute res;
2900+
if (isa<IntegerType>(elemTy))
2901+
res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(operands,
2902+
folder);
2903+
if (isa<FloatType>(elemTy))
2904+
res = constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void>(operands,
2905+
folder);
2906+
if (res)
2907+
return cast<TypedAttr>(res);
2908+
2909+
return nullptr;
2910+
}
2911+
2912+
/// Returns true only when all of the following hold for `op`:
///   * its broadcast dimensions are NOT the identity sequence 0, 1, ..., n-1;
///   * its operand is not a splat constant;
///   * operand and result types both have static shapes;
///   * operand and result hold the same number of elements;
///   * operand and result ranks differ.
/// Every other broadcast_in_dim (including one whose operand and result ranks
/// are equal, or whose dimensions are the identity sequence) yields false.
static bool isTransposeReshapeLikeBroadcast(stablehlo::BroadcastInDimOp op) {
  TypedValue<RankedTensorType> operand = op.getOperand();
  RankedTensorType operandTy = operand.getType();
  RankedTensorType resultTy = op.getType();

  // Identity broadcast dimensions never qualify, whatever the shapes are.
  if (isIotaRange(op.getBroadcastDimensions()))
    return false;

  // Broadcasts of splat constants are excluded.
  SplatElementsAttr splatAttr;
  if (matchPattern(operand, m_Constant(&splatAttr)))
    return false;

  // Element counts can only be compared for fully static shapes.
  if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
    return false;
  if (resultTy.getNumElements() != operandTy.getNumElements())
    return false;

  // Equal rank is rejected; only a rank-changing broadcast is accepted.
  return resultTy.getRank() != operandTy.getRank();
}
2944+
28722945
// slice(broadcast x) -> broadcast(slice x)
28732946
struct SliceBroadcast final
28742947
: CheckedOpRewritePattern<stablehlo::SliceOp, SliceBroadcast> {
@@ -2884,6 +2957,9 @@ struct SliceBroadcast final
28842957
if (!bcast)
28852958
return failure();
28862959

2960+
if (isTransposeReshapeLikeBroadcast(bcast))
2961+
return failure();
2962+
28872963
SmallVector<int64_t> nbcast_idx;
28882964

28892965
auto preShape = cast<RankedTensorType>(bcast.getOperand().getType());
@@ -3138,6 +3214,52 @@ struct TransposeSliceBase final
31383214
}
31393215
};
31403216

3217+
// transpose(dynamic_slice x) -> dynamic_slice(transpose x)
3218+
// transpose(slice x) -> slice(transpose x)
3219+
template <typename OpTy>
3220+
struct TransposeLikeBroadcastSliceBase final
3221+
: CheckedOpRewritePattern<stablehlo::BroadcastInDimOp,
3222+
TransposeLikeBroadcastSliceBase<OpTy>> {
3223+
using CheckedOpRewritePattern<
3224+
stablehlo::BroadcastInDimOp,
3225+
TransposeLikeBroadcastSliceBase<OpTy>>::CheckedOpRewritePattern;
3226+
3227+
LogicalResult matchAndRewriteImpl(stablehlo::BroadcastInDimOp op,
3228+
PatternRewriter &rewriter) const {
3229+
if (!isTransposeReshapeLikeBroadcast(op))
3230+
return failure();
3231+
auto sliceOp = op.getOperand().template getDefiningOp<OpTy>();
3232+
if (!sliceOp) {
3233+
return failure();
3234+
}
3235+
3236+
// If we can fuse the transpose into all of its users then we shouldn't push
3237+
// it up (or atleast give higher priority to other passes before trying to
3238+
// move it up)
3239+
if (llvm::all_of(op->getUsers(),
3240+
[&](auto user) { return isFusible(op, user); })) {
3241+
return failure();
3242+
}
3243+
3244+
bool singleUser = sliceOp->getResult(0).hasOneUse();
3245+
3246+
SmallVector<int64_t> shape = llvm::to_vector(op.getType().getShape());
3247+
for (auto &&[i, v] : llvm::enumerate(op.getBroadcastDimensions())) {
3248+
shape[v] = sliceOp.getOperand().getType().getShape()[i];
3249+
}
3250+
auto newTranspose = stablehlo::BroadcastInDimOp::create(
3251+
rewriter, op.getLoc(),
3252+
RankedTensorType::get(shape, op.getType().getElementType()),
3253+
sliceOp.getOperand(), op.getBroadcastDimensions());
3254+
auto newSlice = transposeLikeSliceHelper(newTranspose, rewriter, sliceOp);
3255+
rewriter.replaceOp(op, newSlice);
3256+
if (singleUser) {
3257+
rewriter.eraseOp(sliceOp);
3258+
}
3259+
return success();
3260+
}
3261+
};
3262+
31413263
struct TransposeAllUsersSlice final
31423264
: public CheckedOpRewritePattern<stablehlo::TransposeOp,
31433265
TransposeAllUsersSlice> {
@@ -6553,6 +6675,46 @@ struct BroadcastReshape final
65536675
if (!type)
65546676
return failure();
65556677

6678+
if (reshape.getType().hasStaticShape() && type.hasStaticShape() &&
6679+
type.getShape().size() >=
6680+
reshape.getOperand().getType().getShape().size()) {
6681+
6682+
auto deletionDims = findReshapeInsertionDims(
6683+
reshape.getType(), reshape.getOperand().getType());
6684+
if (!deletionDims.empty()) {
6685+
SmallVector<int64_t> unusedSingletons;
6686+
6687+
for (auto &&[i, s] : llvm::enumerate(type.getShape())) {
6688+
if (llvm::is_contained(op.getBroadcastDimensions(), i))
6689+
continue;
6690+
unusedSingletons.push_back(i);
6691+
}
6692+
6693+
SmallVector<int64_t> bcast(
6694+
reshape.getOperand().getType().getShape().size(), 0);
6695+
6696+
size_t bcast_idx = 0;
6697+
size_t unused_idx = 0;
6698+
6699+
for (size_t i = 0; i < reshape.getOperand().getType().getShape().size();
6700+
i++) {
6701+
if (llvm::is_contained(deletionDims, i)) {
6702+
bcast[i] = unusedSingletons[unused_idx];
6703+
unused_idx++;
6704+
6705+
} else {
6706+
bcast[i] = op.getBroadcastDimensions()[bcast_idx];
6707+
bcast_idx++;
6708+
}
6709+
}
6710+
6711+
rewriter.replaceOpWithNewOp<stablehlo::BroadcastInDimOp>(
6712+
op, op.getType(), reshape.getOperand(), bcast);
6713+
6714+
return success();
6715+
}
6716+
}
6717+
65566718
SmallVector<int64_t> dims;
65576719

65586720
size_t pre_reshape_idx = 0;
@@ -11061,6 +11223,64 @@ struct SliceReshapeElementwise final
1106111223
}
1106211224
};
1106311225

11226+
struct TransposeLikeBroadcastElementwise final
11227+
: CheckedOpRewritePattern<stablehlo::BroadcastInDimOp,
11228+
TransposeLikeBroadcastElementwise> {
11229+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
11230+
11231+
bool onlySingleUser;
11232+
11233+
TransposeLikeBroadcastElementwise(bool onlySingleUser, MLIRContext *context,
11234+
PatternBenefit benefit = 1,
11235+
ArrayRef<StringRef> generatedNames = {})
11236+
: CheckedOpRewritePattern(context, benefit, generatedNames),
11237+
onlySingleUser(onlySingleUser) {}
11238+
11239+
LogicalResult matchAndRewriteImpl(stablehlo::BroadcastInDimOp op,
11240+
PatternRewriter &rewriter) const {
11241+
if (!isTransposeReshapeLikeBroadcast(op))
11242+
return failure();
11243+
auto elem = op.getOperand().getDefiningOp();
11244+
if (!elem)
11245+
return failure();
11246+
11247+
if (!stablehlo::hasTraitElementwise(elem))
11248+
return failure();
11249+
11250+
bool singleUser = llvm::hasSingleElement(elem->getUsers());
11251+
if (onlySingleUser && !singleUser)
11252+
return failure();
11253+
11254+
SmallVector<Value> ops;
11255+
for (auto v : elem->getOperands()) {
11256+
if (auto rop = v.getDefiningOp()) {
11257+
rewriter.setInsertionPointAfter(rop);
11258+
} else if (auto ba = dyn_cast<BlockArgument>(v)) {
11259+
rewriter.setInsertionPointToStart(ba.getOwner());
11260+
}
11261+
auto nt = RankedTensorType::get(
11262+
op.getType().getShape(),
11263+
cast<RankedTensorType>(v.getType()).getElementType());
11264+
ops.push_back(stablehlo::BroadcastInDimOp::create(
11265+
rewriter, op.getLoc(), nt, v, op.getBroadcastDimensions()));
11266+
}
11267+
if (singleUser) {
11268+
rewriter.modifyOpInPlace(elem, [&]() {
11269+
elem->setOperands(ops);
11270+
elem->getResult(0).setType(op.getType());
11271+
});
11272+
rewriter.replaceOp(op, elem);
11273+
} else {
11274+
rewriter.setInsertionPointAfter(elem);
11275+
auto newOp = rewriter.create(
11276+
elem->getLoc(), elem->getName().getIdentifier(), ValueRange(ops),
11277+
TypeRange(op.getType()), elem->getAttrs(), {}, {});
11278+
rewriter.replaceOp(op, newOp);
11279+
}
11280+
return success();
11281+
}
11282+
};
11283+
1106411284
struct TransposeElementwise final
1106511285
: CheckedOpRewritePattern<stablehlo::TransposeOp, TransposeElementwise> {
1106611286
using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -12519,47 +12739,6 @@ struct DivideDivideSimplify
1251912739
}
1252012740
};
1252112741

12522-
//////////////// Imported from stablehlo
12523-
static bool isIotaRange(ArrayRef<int64_t> dims) {
12524-
return llvm::all_of(llvm::enumerate(dims), [](const auto &it) {
12525-
return static_cast<int64_t>(it.index()) == it.value();
12526-
});
12527-
}
12528-
12529-
/// Matches when either of the submatchers match.
12530-
template <typename MatcherA, typename MatcherB> struct m_AnyOf {
12531-
m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {}
12532-
12533-
bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); }
12534-
12535-
MatcherA matcherA;
12536-
MatcherB matcherB;
12537-
};
12538-
12539-
template <typename MatcherA, typename MatcherB>
12540-
m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>;
12541-
12542-
/// Binary constant folder that used a generic folder function to handle both
12543-
/// ints and floats.
12544-
template <typename Fn>
12545-
static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs,
12546-
Fn &&folder) {
12547-
Attribute operands[2] = {lhs, rhs};
12548-
Type elemTy = getElementTypeOrSelf(lhs);
12549-
12550-
Attribute res;
12551-
if (isa<IntegerType>(elemTy))
12552-
res = constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>(operands,
12553-
folder);
12554-
if (isa<FloatType>(elemTy))
12555-
res = constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void>(operands,
12556-
folder);
12557-
if (res)
12558-
return cast<TypedAttr>(res);
12559-
12560-
return nullptr;
12561-
}
12562-
1256312742
static stablehlo::ComparisonDirection
1256412743
invertDirection(stablehlo::ComparisonDirection direction) {
1256512744
using stablehlo::ComparisonDirection;
@@ -31843,6 +32022,13 @@ void mlir::transform::addTransposeElementwise(RewritePatternSet &patterns,
3184332022
patterns.insert<TransposeElementwise>(onlySingleUser, &context, benefit);
3184432023
}
3184532024

32025+
/// Adds the TransposeLikeBroadcastElementwise pattern to `patterns`,
/// constructed with the given `onlySingleUser` flag, `context` and `benefit`.
void mlir::transform::addTransposeLikeBroadcastElementwise(
    RewritePatternSet &patterns, bool onlySingleUser, MLIRContext &context,
    PatternBenefit benefit) {
  MLIRContext *ctx = &context;
  patterns.insert<TransposeLikeBroadcastElementwise>(onlySingleUser, ctx,
                                                     benefit);
}
32031+
3184632032
void mlir::transform::addReshapeElementwise(RewritePatternSet &patterns,
3184732033
bool onlySingleUser,
3184832034
MLIRContext &context,
@@ -32095,9 +32281,11 @@ struct EnzymeHLOOptPass
3209532281

3209632282
if (passses & (2048 * 32)) {
3209732283
patterns.add<TransposeWhile, TransposeSliceBase<stablehlo::SliceOp>,
32284+
TransposeLikeBroadcastSliceBase<stablehlo::SliceOp>,
3209832285
TransposeConcat, TransposeDUS, TransposeIota,
3209932286
TransposeReduceWindow, TransposeReduce, TransposeSelect,
3210032287
TransposeSliceBase<stablehlo::DynamicSliceOp>,
32288+
TransposeLikeBroadcastSliceBase<stablehlo::DynamicSliceOp>,
3210132289
TransposeReverse, TransposeBatchNormTraining,
3210232290
TransposeBatchNormInference, TransposeBatchNormGrad,
3210332291
TransposeIf, TransposeFFT, TransposeReshape>(context);

src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ void addConcatenateOpCanon(RewritePatternSet &patterns,
109109
PatternBenefit benefit);
110110
void addTransposeElementwise(RewritePatternSet &patterns, bool onlySingleUser,
111111
MLIRContext &context, PatternBenefit benefit);
112+
void addTransposeLikeBroadcastElementwise(RewritePatternSet &patterns,
113+
bool onlySingleUser,
114+
MLIRContext &context,
115+
PatternBenefit benefit);
112116
void addReshapeElementwise(RewritePatternSet &patterns, bool onlySingleUser,
113117
MLIRContext &context, PatternBenefit benefit);
114118
void addReshapeElementwiseOnlyFusible(RewritePatternSet &patterns,

src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,10 +485,11 @@ struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
485485
rewriter, op.getLoc(), cast<RankedTensorType>(op.getC().getType()),
486486
op.getA(), op.getA(), dotDims, nullptr, nullptr);
487487

488-
auto res = stablehlo::AddOpCreate(
489-
rewriter, op->getLoc(),
490-
stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getAlpha(), AAT),
491-
stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getBeta(), C));
488+
auto aop =
489+
stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getAlpha(), AAT);
490+
auto bop = stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getBeta(), C);
491+
492+
auto res = stablehlo::AddOpCreate(rewriter, op->getLoc(), aop, bop);
492493
rewriter.replaceOp(op, res);
493494
return success();
494495
}

src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ LogicalResult parseTransform(OpBuilder &builder, Location loc,
160160
opName == "dynamic_slice_licm" || opName == "scatter_licm" ||
161161
opName == "gather_licm" || opName == "iota_licm" ||
162162
opName == "transpose_elementwise" ||
163+
opName == "transpose_like_broadcast_elementwise" ||
163164
opName == "reshape_elementwise" ||
164165
opName == "reshape_elementwise_only_fusible" ||
165166
opName == "reshape_slice" || opName == "reshape_dynamic_slice" ||

src/enzyme_ad/jax/TransformOps/TransformOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,12 @@ void ApplyTransposeElementwisePatterns::populatePatterns(
209209
addTransposeElementwise(patterns, getParameter(), *getContext(),
210210
PatternBenefit(getBenefit().value_or(1)));
211211
}
212+
/// Populates `patterns` through addTransposeLikeBroadcastElementwise,
/// forwarding this op's parameter and context, and using a pattern benefit of
/// 1 when no benefit is set on the op.
void ApplyTransposeLikeBroadcastElementwisePatterns::populatePatterns(
    RewritePatternSet &patterns) {
  const PatternBenefit benefit(getBenefit().value_or(1));
  addTransposeLikeBroadcastElementwise(patterns, getParameter(), *getContext(),
                                       benefit);
}
212218
void ApplyReshapeElementwisePatterns::populatePatterns(
213219
RewritePatternSet &patterns) {
214220
addReshapeElementwise(patterns, getParameter(), *getContext(),

0 commit comments

Comments
 (0)