@@ -3683,7 +3683,10 @@ struct SliceElementwise final
36833683 return failure();
36843684 if (!stablehlo::hasTraitElementwise(elem))
36853685 return failure();
3686- if (llvm::hasSingleElement(elem->getUsers())) {
3686+ if (llvm::hasSingleElement(elem->getUsers()) ||
3687+ (llvm::hasSingleElement(op.getResult().getUsers()) &&
3688+ isa<stablehlo::ConcatenateOp>(
3689+ op.getResult().use_begin()->getOwner()))) {
36873690 SmallVector<Value> ops;
36883691 for (auto v : elem->getOperands()) {
36893692 ops.push_back(stablehlo::SliceOp::create(
@@ -13517,8 +13520,7 @@ struct SelectBroadcastIota final
1351713520 return failure();
1351813521
1351913522 // broadcast input must be a compare
13520- auto compare =
13521- broadcast.getOperand().getDefiningOp<stablehlo::CompareOp>();
13523+ auto compare = broadcast.getOperand().getDefiningOp<stablehlo::CompareOp>();
1352213524 if (!compare)
1352313525 return failure();
1352413526
@@ -13639,9 +13641,8 @@ struct SelectBroadcastIota final
1363913641 break;
1364013642 }
1364113643
13642- auto validCount =
13643- std::count_if(slices.begin(), slices.end(),
13644- [](slice_data d) { return d.count > 0; });
13644+ auto validCount = std::count_if(slices.begin(), slices.end(),
13645+ [](slice_data d) { return d.count > 0; });
1364513646 if (validCount == 1) {
1364613647 for (auto &e : slices)
1364713648 if (e.count > 0) {
@@ -28417,163 +28418,6 @@ struct SubtractMultiplyConstToAddMulConst
2841728418 }
2841828419};
2841928420
28420- // Match: sub(pad(x, 0, lo=0, hi=K, dim=d), pad(x, 0, lo=K, hi=0, dim=d))
28421- // Rewrite to a convolution with kernel [-1, 1] and inline padding [K, K] on dim
28422- // d. This handles the common "backward finite difference" pattern as a 1-D
28423- // cross-correlation.
28424- struct PadSubToConvolution
28425- : public CheckedOpRewritePattern<stablehlo::SubtractOp,
28426- PadSubToConvolution> {
28427- using CheckedOpRewritePattern::CheckedOpRewritePattern;
28428-
28429- LogicalResult matchAndRewriteImpl(stablehlo::SubtractOp op,
28430- PatternRewriter &rewriter) const {
28431- auto lhsPad = op.getLhs().getDefiningOp<stablehlo::PadOp>();
28432- auto rhsPad = op.getRhs().getDefiningOp<stablehlo::PadOp>();
28433- if (!lhsPad || !rhsPad)
28434- return rewriter.notifyMatchFailure(op, "operands are not pad ops");
28435-
28436- if (lhsPad.getOperand() != rhsPad.getOperand())
28437- return rewriter.notifyMatchFailure(op, "pads have different operands");
28438-
28439- if (anyPadSizesNegative(lhsPad) || anyPadSizesNegative(rhsPad))
28440- return rewriter.notifyMatchFailure(op, "pads have negative sizes");
28441-
28442- if (!llvm::all_of(lhsPad.getInteriorPadding(),
28443- [](int64_t v) { return v == 0; }) ||
28444- !llvm::all_of(rhsPad.getInteriorPadding(),
28445- [](int64_t v) { return v == 0; }))
28446- return rewriter.notifyMatchFailure(op, "interior padding is not zero");
28447-
28448- if ((!matchPattern(lhsPad.getPaddingValue(), m_AnyZeroFloat()) &&
28449- !matchPattern(lhsPad.getPaddingValue(), m_Zero())) ||
28450- (!matchPattern(rhsPad.getPaddingValue(), m_AnyZeroFloat()) &&
28451- !matchPattern(rhsPad.getPaddingValue(), m_Zero())))
28452- return rewriter.notifyMatchFailure(op, "padding value is not zero");
28453-
28454- auto lhsLow = lhsPad.getEdgePaddingLow();
28455- auto lhsHigh = lhsPad.getEdgePaddingHigh();
28456- auto rhsLow = rhsPad.getEdgePaddingLow();
28457- auto rhsHigh = rhsPad.getEdgePaddingHigh();
28458-
28459- auto outType = cast<RankedTensorType>(op.getType());
28460- int64_t rank = outType.getRank();
28461-
28462- // Find the single dimension with complementary padding; all others must be
28463- // 0.
28464- int64_t diffDim = -1;
28465- int64_t shiftAmount = -1;
28466- // lhsHasHighPad=true: lhs has hi=K, rhs has lo=K → x[i] - x[i-K]
28467- // lhsHasHighPad=false: lhs has lo=K, rhs has hi=K → x[i-K] - x[i]
28468- bool lhsHasHighPad = false;
28469-
28470- for (int64_t d = 0; d < rank; d++) {
28471- if (lhsLow[d] == 0 && lhsHigh[d] == 0 && rhsLow[d] == 0 &&
28472- rhsHigh[d] == 0)
28473- continue;
28474-
28475- if (diffDim != -1)
28476- return rewriter.notifyMatchFailure(
28477- op, "more than one dimension differs in padding");
28478-
28479- if (lhsLow[d] == 0 && rhsHigh[d] == 0 && lhsHigh[d] == rhsLow[d] &&
28480- lhsHigh[d] > 0) {
28481- diffDim = d;
28482- shiftAmount = lhsHigh[d];
28483- lhsHasHighPad = true;
28484- } else if (lhsHigh[d] == 0 && rhsLow[d] == 0 && lhsLow[d] == rhsHigh[d] &&
28485- lhsLow[d] > 0) {
28486- diffDim = d;
28487- shiftAmount = lhsLow[d];
28488- lhsHasHighPad = false;
28489- } else {
28490- return rewriter.notifyMatchFailure(
28491- op, "padding is not complementary in differing dimension");
28492- }
28493- }
28494-
28495- if (diffDim == -1)
28496- return rewriter.notifyMatchFailure(op, "no differencing dimension found");
28497-
28498- auto loc = op.getLoc();
28499- auto T = outType.getElementType();
28500- auto scalarType = RankedTensorType::get({}, T);
28501-
28502- // Reshape input x: [d0,...,dN] -> [1, 1, d0,...,dN]
28503- auto inputType = cast<RankedTensorType>(lhsPad.getOperand().getType());
28504- SmallVector<int64_t> convInputShape(2, 1);
28505- for (auto d : inputType.getShape())
28506- convInputShape.push_back(d);
28507- auto convInput = stablehlo::ReshapeOpCreate(
28508- rewriter, loc, lhsPad.getOperand(), convInputShape);
28509-
28510- // Build kernel: two scalar elements reshaped then concatenated along
28511- // diffDim+2. lhsHasHighPad=true → [-1, 1] lhsHasHighPad=false → [1, -1]
28512- auto negOne = stablehlo::ConstantOp::create(
28513- rewriter, loc, scalarType,
28514- cast<ElementsAttr>(makeAttr(scalarType, -1)));
28515- auto posOne = stablehlo::ConstantOp::create(
28516- rewriter, loc, scalarType, cast<ElementsAttr>(makeAttr(scalarType, 1)));
28517-
28518- SmallVector<int64_t> filterElemShape(rank + 2, 1);
28519- auto negOneReshaped =
28520- stablehlo::ReshapeOpCreate(rewriter, loc, negOne, filterElemShape);
28521- auto posOneReshaped =
28522- stablehlo::ReshapeOpCreate(rewriter, loc, posOne, filterElemShape);
28523-
28524- Value firstElem = lhsHasHighPad ? negOneReshaped : posOneReshaped;
28525- Value secondElem = lhsHasHighPad ? posOneReshaped : negOneReshaped;
28526-
28527- auto filter = stablehlo::ConcatenateOp::create(
28528- rewriter, loc, ValueRange{firstElem, secondElem},
28529- rewriter.getI64IntegerAttr(diffDim + 2));
28530-
28531- // Spatial dims: [2, 3, ..., rank+1]
28532- SmallVector<int64_t> spatialDims(rank);
28533- for (int64_t i = 0; i < rank; ++i)
28534- spatialDims[i] = i + 2;
28535-
28536- auto convDims = stablehlo::ConvDimensionNumbersAttr::get(
28537- rewriter.getContext(),
28538- /*input_batch_dimension=*/0,
28539- /*input_feature_dimension=*/1,
28540- /*input_spatial_dimensions=*/spatialDims,
28541- /*kernel_input_feature_dimension=*/0,
28542- /*kernel_output_feature_dimension=*/1,
28543- /*kernel_spatial_dimensions=*/spatialDims,
28544- /*output_batch_dimension=*/0,
28545- /*output_feature_dimension=*/1,
28546- /*output_spatial_dimensions=*/spatialDims);
28547-
28548- // Inline padding: [shiftAmount, shiftAmount] on diffDim, 0 elsewhere.
28549- SmallVector<int64_t> paddingVals(2 * rank, 0);
28550- paddingVals[2 * diffDim] = shiftAmount;
28551- paddingVals[2 * diffDim + 1] = shiftAmount;
28552- auto paddingType = RankedTensorType::get({rank, 2}, rewriter.getI64Type());
28553- auto paddingAttr = DenseIntElementsAttr::get(paddingType, paddingVals);
28554-
28555- SmallVector<int64_t> convOutShape(2, 1);
28556- for (auto d : outType.getShape())
28557- convOutShape.push_back(d);
28558- auto convOutType = RankedTensorType::get(convOutShape, T);
28559-
28560- auto conv = stablehlo::ConvolutionOp::create(
28561- rewriter, loc, convOutType, convInput, filter,
28562- /*window_strides=*/nullptr,
28563- /*padding=*/paddingAttr,
28564- /*lhs_dilation=*/nullptr,
28565- /*rhs_dilation=*/nullptr,
28566- /*window_reversal=*/nullptr,
28567- /*conv_dimension_numbers=*/convDims,
28568- /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
28569- /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
28570- /*precision_config=*/nullptr);
28571-
28572- rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op.getType(), conv);
28573- return success();
28574- }
28575- };
28576-
2857728421template <typename OpTy>
2857828422struct SelfElementwiseToConvolutionLike
2857928423 : public CheckedOpRewritePattern<OpTy,