Skip to content

Commit d960d58

Browse files
committed
conjoined triangle of success
1 parent 6134cf9 commit d960d58

File tree

2 files changed: 32 additions (+32) and 163 deletions (−163).

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 7 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -3683,7 +3683,10 @@ struct SliceElementwise final
36833683
return failure();
36843684
if (!stablehlo::hasTraitElementwise(elem))
36853685
return failure();
3686-
if (llvm::hasSingleElement(elem->getUsers())) {
3686+
if (llvm::hasSingleElement(elem->getUsers()) ||
3687+
(llvm::hasSingleElement(op.getResult().getUsers()) &&
3688+
isa<stablehlo::ConcatenateOp>(
3689+
op.getResult().use_begin()->getOwner()))) {
36873690
SmallVector<Value> ops;
36883691
for (auto v : elem->getOperands()) {
36893692
ops.push_back(stablehlo::SliceOp::create(
@@ -13517,8 +13520,7 @@ struct SelectBroadcastIota final
1351713520
return failure();
1351813521

1351913522
// broadcast input must be a compare
13520-
auto compare =
13521-
broadcast.getOperand().getDefiningOp<stablehlo::CompareOp>();
13523+
auto compare = broadcast.getOperand().getDefiningOp<stablehlo::CompareOp>();
1352213524
if (!compare)
1352313525
return failure();
1352413526

@@ -13639,9 +13641,8 @@ struct SelectBroadcastIota final
1363913641
break;
1364013642
}
1364113643

13642-
auto validCount =
13643-
std::count_if(slices.begin(), slices.end(),
13644-
[](slice_data d) { return d.count > 0; });
13644+
auto validCount = std::count_if(slices.begin(), slices.end(),
13645+
[](slice_data d) { return d.count > 0; });
1364513646
if (validCount == 1) {
1364613647
for (auto &e : slices)
1364713648
if (e.count > 0) {
@@ -28417,163 +28418,6 @@ struct SubtractMultiplyConstToAddMulConst
2841728418
}
2841828419
};
2841928420

// Match: sub(pad(x, 0, lo=0, hi=K, dim=d), pad(x, 0, lo=K, hi=0, dim=d))
// Rewrite to a convolution with kernel [-1, 1] and inline padding [K, K] on dim
// d. This handles the common "backward finite difference" pattern as a 1-D
// cross-correlation.
struct PadSubToConvolution
    : public CheckedOpRewritePattern<stablehlo::SubtractOp,
                                     PadSubToConvolution> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  // Returns success() and replaces `op` with reshape(convolution(...)) when
  // the subtract's two operands are complementary zero-pads of one value;
  // otherwise notifies a match failure with a reason string.
  LogicalResult matchAndRewriteImpl(stablehlo::SubtractOp op,
                                    PatternRewriter &rewriter) const {
    // Both subtract operands must be produced by pad ops.
    auto lhsPad = op.getLhs().getDefiningOp<stablehlo::PadOp>();
    auto rhsPad = op.getRhs().getDefiningOp<stablehlo::PadOp>();
    if (!lhsPad || !rhsPad)
      return rewriter.notifyMatchFailure(op, "operands are not pad ops");

    // The pads must wrap the *same* SSA value x, or the difference is not a
    // shifted self-difference.
    if (lhsPad.getOperand() != rhsPad.getOperand())
      return rewriter.notifyMatchFailure(op, "pads have different operands");

    // anyPadSizesNegative is a helper defined elsewhere in this file;
    // negative (cropping) padding would invalidate the shift analysis below.
    if (anyPadSizesNegative(lhsPad) || anyPadSizesNegative(rhsPad))
      return rewriter.notifyMatchFailure(op, "pads have negative sizes");

    // Interior (dilation-style) padding cannot be expressed by the conv's
    // inline edge padding, so require it to be zero on both pads.
    if (!llvm::all_of(lhsPad.getInteriorPadding(),
                      [](int64_t v) { return v == 0; }) ||
        !llvm::all_of(rhsPad.getInteriorPadding(),
                      [](int64_t v) { return v == 0; }))
      return rewriter.notifyMatchFailure(op, "interior padding is not zero");

    // The padded-in value must be 0 (float or integer zero) so the pad
    // borders match the convolution's implicit zero padding.
    if ((!matchPattern(lhsPad.getPaddingValue(), m_AnyZeroFloat()) &&
         !matchPattern(lhsPad.getPaddingValue(), m_Zero())) ||
        (!matchPattern(rhsPad.getPaddingValue(), m_AnyZeroFloat()) &&
         !matchPattern(rhsPad.getPaddingValue(), m_Zero())))
      return rewriter.notifyMatchFailure(op, "padding value is not zero");

    auto lhsLow = lhsPad.getEdgePaddingLow();
    auto lhsHigh = lhsPad.getEdgePaddingHigh();
    auto rhsLow = rhsPad.getEdgePaddingLow();
    auto rhsHigh = rhsPad.getEdgePaddingHigh();

    auto outType = cast<RankedTensorType>(op.getType());
    int64_t rank = outType.getRank();

    // Find the single dimension with complementary padding; all others must be
    // 0.
    int64_t diffDim = -1;
    int64_t shiftAmount = -1;
    // lhsHasHighPad=true: lhs has hi=K, rhs has lo=K → x[i] - x[i-K]
    // lhsHasHighPad=false: lhs has lo=K, rhs has hi=K → x[i-K] - x[i]
    bool lhsHasHighPad = false;

    for (int64_t d = 0; d < rank; d++) {
      // Dimensions with no padding on either side are ignored.
      if (lhsLow[d] == 0 && lhsHigh[d] == 0 && rhsLow[d] == 0 &&
          rhsHigh[d] == 0)
        continue;

      // Only a single differencing dimension is supported.
      if (diffDim != -1)
        return rewriter.notifyMatchFailure(
            op, "more than one dimension differs in padding");

      if (lhsLow[d] == 0 && rhsHigh[d] == 0 && lhsHigh[d] == rhsLow[d] &&
          lhsHigh[d] > 0) {
        // lhs padded high by K, rhs padded low by the same K.
        diffDim = d;
        shiftAmount = lhsHigh[d];
        lhsHasHighPad = true;
      } else if (lhsHigh[d] == 0 && rhsLow[d] == 0 && lhsLow[d] == rhsHigh[d] &&
                 lhsLow[d] > 0) {
        // Mirror case: lhs padded low by K, rhs padded high by K.
        diffDim = d;
        shiftAmount = lhsLow[d];
        lhsHasHighPad = false;
      } else {
        return rewriter.notifyMatchFailure(
            op, "padding is not complementary in differing dimension");
      }
    }

    if (diffDim == -1)
      return rewriter.notifyMatchFailure(op, "no differencing dimension found");

    auto loc = op.getLoc();
    auto T = outType.getElementType();
    auto scalarType = RankedTensorType::get({}, T);

    // Reshape input x: [d0,...,dN] -> [1, 1, d0,...,dN]
    // (prepend unit batch and feature dims; every original dim is spatial).
    auto inputType = cast<RankedTensorType>(lhsPad.getOperand().getType());
    SmallVector<int64_t> convInputShape(2, 1);
    for (auto d : inputType.getShape())
      convInputShape.push_back(d);
    auto convInput = stablehlo::ReshapeOpCreate(
        rewriter, loc, lhsPad.getOperand(), convInputShape);

    // Build kernel: two scalar elements reshaped then concatenated along
    // diffDim+2. lhsHasHighPad=true → [-1, 1] lhsHasHighPad=false → [1, -1]
    // makeAttr is a helper defined elsewhere in this file.
    auto negOne = stablehlo::ConstantOp::create(
        rewriter, loc, scalarType,
        cast<ElementsAttr>(makeAttr(scalarType, -1)));
    auto posOne = stablehlo::ConstantOp::create(
        rewriter, loc, scalarType, cast<ElementsAttr>(makeAttr(scalarType, 1)));

    // Each kernel element is rank+2 dimensional with all-1 extents.
    SmallVector<int64_t> filterElemShape(rank + 2, 1);
    auto negOneReshaped =
        stablehlo::ReshapeOpCreate(rewriter, loc, negOne, filterElemShape);
    auto posOneReshaped =
        stablehlo::ReshapeOpCreate(rewriter, loc, posOne, filterElemShape);

    Value firstElem = lhsHasHighPad ? negOneReshaped : posOneReshaped;
    Value secondElem = lhsHasHighPad ? posOneReshaped : negOneReshaped;

    auto filter = stablehlo::ConcatenateOp::create(
        rewriter, loc, ValueRange{firstElem, secondElem},
        rewriter.getI64IntegerAttr(diffDim + 2));

    // Spatial dims: [2, 3, ..., rank+1]
    SmallVector<int64_t> spatialDims(rank);
    for (int64_t i = 0; i < rank; ++i)
      spatialDims[i] = i + 2;

    // NHW-style numbering: dim 0 = batch, dim 1 = feature, rest spatial, for
    // input, kernel, and output alike (kernel uses 0 = input feature,
    // 1 = output feature).
    auto convDims = stablehlo::ConvDimensionNumbersAttr::get(
        rewriter.getContext(),
        /*input_batch_dimension=*/0,
        /*input_feature_dimension=*/1,
        /*input_spatial_dimensions=*/spatialDims,
        /*kernel_input_feature_dimension=*/0,
        /*kernel_output_feature_dimension=*/1,
        /*kernel_spatial_dimensions=*/spatialDims,
        /*output_batch_dimension=*/0,
        /*output_feature_dimension=*/1,
        /*output_spatial_dimensions=*/spatialDims);

    // Inline padding: [shiftAmount, shiftAmount] on diffDim, 0 elsewhere.
    SmallVector<int64_t> paddingVals(2 * rank, 0);
    paddingVals[2 * diffDim] = shiftAmount;
    paddingVals[2 * diffDim + 1] = shiftAmount;
    auto paddingType = RankedTensorType::get({rank, 2}, rewriter.getI64Type());
    auto paddingAttr = DenseIntElementsAttr::get(paddingType, paddingVals);

    // Conv output shape mirrors the subtract's result with the two synthetic
    // leading unit dims prepended.
    SmallVector<int64_t> convOutShape(2, 1);
    for (auto d : outType.getShape())
      convOutShape.push_back(d);
    auto convOutType = RankedTensorType::get(convOutShape, T);

    auto conv = stablehlo::ConvolutionOp::create(
        rewriter, loc, convOutType, convInput, filter,
        /*window_strides=*/nullptr,
        /*padding=*/paddingAttr,
        /*lhs_dilation=*/nullptr,
        /*rhs_dilation=*/nullptr,
        /*window_reversal=*/nullptr,
        /*conv_dimension_numbers=*/convDims,
        /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
        /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
        /*precision_config=*/nullptr);

    // Drop the synthetic batch/feature dims to restore the original shape.
    rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(), conv);
    return success();
  }
};
2857728421
template <typename OpTy>
2857828422
struct SelfElementwiseToConvolutionLike
2857928423
: public CheckedOpRewritePattern<OpTy,
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-ops=passses=131072 | FileCheck %s
// NOTE(review): "passses" looks misspelled but may be the literal option key
// the pass registers — confirm against the pass definition before changing it.

module {
  // Regression test: %3197 multiplies row 0 of %3183 by arg6[9], while %3195
  // multiplies every row i of %3183 by arg6[9 + i]. Hence %3197 equals
  // slice %3195 [0:1, 0:3056], and concatenating %3197 with rows 1..1518 of
  // %3195 reconstructs %3195 itself. The optimizer is expected to collapse
  // the whole pattern and return the unsliced multiply for both results.
  func.func @main(%arg6: tensor<1536xf64>, %3183: tensor<1519x3056xf64>) -> (tensor<1519x3056xf64>, tensor<1519x3056xf64>) {
    // Row 0 of the product, built via sliced operands.
    %3187 = stablehlo.slice %3183 [0:1, 0:3056] : (tensor<1519x3056xf64>) -> tensor<1x3056xf64>
    %3186 = stablehlo.slice %arg6 [9:10] : (tensor<1536xf64>) -> tensor<1xf64>
    %3196 = stablehlo.broadcast_in_dim %3186, dims = [0] : (tensor<1xf64>) -> tensor<1x3056xf64>
    %3197 = stablehlo.multiply %3196, %3187 : tensor<1x3056xf64>

    // Full product, then its rows 1..1518.
    %3192 = stablehlo.slice %arg6 [9:1528] : (tensor<1536xf64>) -> tensor<1519xf64>
    %3193 = stablehlo.broadcast_in_dim %3192, dims = [0] : (tensor<1519xf64>) -> tensor<1519x3056xf64>
    %3195 = stablehlo.multiply %3193, %3183 : tensor<1519x3056xf64>
    %3198 = stablehlo.slice %3195 [1:1519, 0:3056] : (tensor<1519x3056xf64>) -> tensor<1518x3056xf64>
    // Concatenation of row 0 with rows 1..1518 == the full product %3195.
    %3199 = stablehlo.concatenate %3197, %3198, dim = 0 : (tensor<1x3056xf64>, tensor<1518x3056xf64>) -> tensor<1519x3056xf64>

    return %3199, %3195 : tensor<1519x3056xf64>, tensor<1519x3056xf64>
  }
}

// Expected fully-simplified output: one slice/broadcast/multiply chain,
// returned for both results.
// CHECK: func.func @main(%arg0: tensor<1536xf64>, %arg1: tensor<1519x3056xf64>) -> (tensor<1519x3056xf64>, tensor<1519x3056xf64>) {
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [9:1528] : (tensor<1536xf64>) -> tensor<1519xf64>
// CHECK-NEXT: %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<1519xf64>) -> tensor<1519x3056xf64>
// CHECK-NEXT: %2 = stablehlo.multiply %1, %arg1 : tensor<1519x3056xf64>
// CHECK-NEXT: return %2, %2 : tensor<1519x3056xf64>, tensor<1519x3056xf64>
// CHECK-NEXT: }

0 commit comments

Comments
 (0)