Commit 14bfc93

feat: dus to dynamic_pad + dynamic_pad to pad (#1430)
* feat: dus to dynamic_pad
* feat: dynamic_pad to pad
* chore: comments
* test: restrict applicability
* feat: add to primitives
1 parent c8efb46 commit 14bfc93
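
In effect, the new patterns let a stablehlo.dynamic_update_slice whose start indices are constant and whose destination is a splat (a broadcast of a scalar, or a splat constant) be rewritten first into a stablehlo.dynamic_pad and then, once the padding amounts fold to constants, into a plain stablehlo.pad. A condensed before/after, taken from @main5 in the new test/lit_tests/dus_to_pad.mlir below:

// Before: write a 2x2x3 update into a zero-filled 2x4x3 tensor at (clamped) constant offsets.
%0 = stablehlo.dynamic_update_slice %cst, %arg0, %c_1, %c_0, %c : (tensor<2x4x3xf32>, tensor<2x2x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2x4x3xf32>

// After --enzyme-hlo-opt: the same result expressed as a static pad of the update
// (%cst is now the scalar padding value rather than the full destination tensor).
%0 = stablehlo.pad %arg0, %cst, low = [0, 0, 0], high = [0, 2, 0], interior = [0, 0, 0] : (tensor<2x2x3xf32>, tensor<f32>) -> tensor<2x4x3xf32>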

File tree (4 files changed: +287, -2 lines)

  src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
  src/enzyme_ad/jax/TransformOps/TransformOps.td
  src/enzyme_ad/jax/primitives.py
  test/lit_tests/dus_to_pad.mlir

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 200 additions & 2 deletions
@@ -5685,6 +5685,52 @@ struct UnaryConstProp final
   }
 };
 
+struct ClampConstProp final
+    : CheckedOpRewritePattern<stablehlo::ClampOp, ClampConstProp> {
+  using CheckedOpRewritePattern::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::ClampOp op,
+                                    PatternRewriter &rewriter) const {
+    DenseElementsAttr minAttr, inputAttr, maxAttr;
+    if (!matchPattern(op.getMin(), m_Constant(&minAttr)) ||
+        !matchPattern(op.getOperand(), m_Constant(&inputAttr)) ||
+        !matchPattern(op.getMax(), m_Constant(&maxAttr)))
+      return failure();
+
+    // TODO: for only min or max with input being constant we can convert this
+    // to a min/max op
+    stablehlo::Tensor minTen, maxTen, inputTen;
+    bool splattedVersion = false;
+    RankedTensorType ty = cast<RankedTensorType>(op->getResultTypes()[0]);
+    if (minAttr.isSplat() && maxAttr.isSplat() && inputAttr.isSplat()) {
+      splattedVersion = true;
+      ty = RankedTensorType::get(
+          {}, cast<ShapedType>(op->getResultTypes()[0]).getElementType());
+      auto inputTy = RankedTensorType::get(
+          {}, cast<ShapedType>(op->getOperand(0).getType()).getElementType());
+      minTen = stablehlo::makeTensor(minAttr.resizeSplat(inputTy));
+      maxTen = stablehlo::makeTensor(maxAttr.resizeSplat(inputTy));
+      inputTen = stablehlo::makeTensor(inputAttr.resizeSplat(inputTy));
+    } else {
+      minTen = stablehlo::constantOp(minAttr);
+      maxTen = stablehlo::constantOp(maxAttr);
+      inputTen = stablehlo::constantOp(inputAttr);
+    }
+
+    auto out =
+        fromTensor(clampOp(inputTen, minTen, maxTen, cast<ShapedType>(ty)));
+
+    if (splattedVersion) {
+      out = out.resizeSplat(cast<ShapedType>(op->getResultTypes()[0]));
+    }
+    // Replace with new constant op containing the computed result
+    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
+        op, op->getResultTypes()[0], out);
+
+    return success();
+  }
+};
+
 struct ChloInfConstProp final
     : CheckedOpRewritePattern<chlo::IsInfOp, ChloInfConstProp> {
   using CheckedOpRewritePattern::CheckedOpRewritePattern;
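
The net effect of the ClampConstProp pattern above is to evaluate a clamp whose min, operand, and max are all constants at compile time (computing on scalars and re-splatting when every input is a splat). A minimal sketch of the folding, with illustrative values not taken from the commit's tests:

// Before: every input to the clamp is a constant splat.
%min = stablehlo.constant dense<0.000000e+00> : tensor<4xf32>
%x   = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
%max = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
%0 = stablehlo.clamp %min, %x, %max : tensor<4xf32>

// After: the pattern evaluates clamp(min, x, max) with the StableHLO reference
// helpers and replaces the op with the resulting constant.
%0 = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>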
@@ -23136,6 +23182,156 @@ struct CaseToIf : public CheckedOpRewritePattern<stablehlo::CaseOp, CaseToIf> {
   }
 };
 
+struct DUSToDynamicPad
+    : public CheckedOpRewritePattern<stablehlo::DynamicUpdateSliceOp,
+                                     DUSToDynamicPad> {
+  using CheckedOpRewritePattern<stablehlo::DynamicUpdateSliceOp,
+                                DUSToDynamicPad>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::DynamicUpdateSliceOp op,
+                                    PatternRewriter &rewriter) const {
+    auto operand = op.getOperand();
+    auto update = op.getUpdate();
+    auto indices = op.getStartIndices();
+
+    for (auto [i, index] : llvm::enumerate(indices)) {
+      if (!matchPattern(index, m_Constant())) {
+        return rewriter.notifyMatchFailure(
+            op, "not all indices are constant. currently we don't support this "
+                "case");
+      }
+    }
+
+    Value scalarOperand = getScalarPadValue(rewriter, operand);
+    if (!scalarOperand)
+      return rewriter.notifyMatchFailure(op, "operand is not a scalar pad");
+
+    auto updateShape = cast<RankedTensorType>(update.getType()).getShape();
+    auto operandShape = cast<RankedTensorType>(operand.getType()).getShape();
+
+    SmallVector<Value> edgePaddingLowValues, edgePaddingHighValues;
+    for (auto [i, index] : llvm::enumerate(indices)) {
+      auto cType = RankedTensorType::get(
+          {}, cast<RankedTensorType>(index.getType()).getElementType());
+      auto clampedIndex = rewriter.create<stablehlo::ClampOp>(
+          op.getLoc(),
+          rewriter.create<stablehlo::ConstantOp>(
+              op.getLoc(), cType, cast<ElementsAttr>(makeAttr(cType, 0))),
+          index,
+          rewriter.create<stablehlo::ConstantOp>(
+              op.getLoc(), cType,
+              cast<ElementsAttr>(
+                  makeAttr(cType, operandShape[i] - updateShape[i]))));
+
+      auto reshapedIndex = rewriter.create<stablehlo::ReshapeOp>(
+          op.getLoc(),
+          RankedTensorType::get(
+              {1}, cast<RankedTensorType>(index.getType()).getElementType()),
+          clampedIndex);
+      edgePaddingLowValues.push_back(reshapedIndex.getResult());
+
+      auto iType = RankedTensorType::get(
+          {1}, cast<RankedTensorType>(index.getType()).getElementType());
+      auto tmp = rewriter.create<stablehlo::ConstantOp>(
+          op.getLoc(), iType,
+          cast<ElementsAttr>(
+              makeAttr(iType, operandShape[i] - updateShape[i])));
+      auto paddingHigh = rewriter.create<stablehlo::SubtractOp>(
+          op.getLoc(), tmp, reshapedIndex);
+      edgePaddingHighValues.push_back(paddingHigh);
+    }
+
+    auto edgePaddingLow = rewriter.create<stablehlo::ConcatenateOp>(
+        op.getLoc(), edgePaddingLowValues, 0);
+    auto edgePaddingHigh = rewriter.create<stablehlo::ConcatenateOp>(
+        op.getLoc(), edgePaddingHighValues, 0);
+    auto interiorPadding = rewriter.create<stablehlo::ConstantOp>(
+        op.getLoc(), edgePaddingLow.getType(),
+        cast<ElementsAttr>(makeAttr(edgePaddingLow.getType(), 0)));
+
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicPadOp>(
+        op, op.getType(), update, scalarOperand, edgePaddingLow,
+        edgePaddingHigh, interiorPadding);
+    return success();
+  }
+
+private:
+  Value getScalarPadValue(PatternRewriter &rewriter, Value operand) const {
+    Value scalarOperand = getScalarPadValueViaBcastInDim(rewriter, operand);
+    if (scalarOperand)
+      return scalarOperand;
+
+    scalarOperand = getScalarPadValueViaSplattedConstant(rewriter, operand);
+    if (scalarOperand)
+      return scalarOperand;
+
+    return nullptr;
+  }
+
+  Value getScalarPadValueViaBcastInDim(PatternRewriter &rewriter,
+                                       Value operand) const {
+    auto bcastInDimOp = operand.getDefiningOp<stablehlo::BroadcastInDimOp>();
+    if (!bcastInDimOp)
+      return nullptr;
+
+    auto bcastOperand = bcastInDimOp.getOperand();
+    auto bcastOperandType = cast<RankedTensorType>(bcastOperand.getType());
+    if (bcastOperandType.getRank() != 0)
+      return nullptr;
+
+    return bcastOperand;
+  }
+
+  Value getScalarPadValueViaSplattedConstant(PatternRewriter &rewriter,
+                                             Value operand) const {
+    SplatElementsAttr splatAttr;
+    if (!matchPattern(operand, m_Constant(&splatAttr)))
+      return nullptr;
+
+    return rewriter.create<stablehlo::ConstantOp>(
+        operand.getLoc(), splatAttr.getSplatValue<Attribute>());
+  }
+};
+
+struct DynamicPadToPad
+    : public CheckedOpRewritePattern<stablehlo::DynamicPadOp, DynamicPadToPad> {
+  using CheckedOpRewritePattern<stablehlo::DynamicPadOp,
+                                DynamicPadToPad>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::DynamicPadOp op,
+                                    PatternRewriter &rewriter) const {
+    auto operand = op.getOperand();
+    auto paddingValue = op.getPaddingValue();
+    auto edgePaddingLow = op.getEdgePaddingLow();
+    auto edgePaddingHigh = op.getEdgePaddingHigh();
+    auto interiorPadding = op.getInteriorPadding();
+
+    DenseIntElementsAttr edgePaddingLowAttr, edgePaddingHighAttr,
+        interiorPaddingAttr;
+    if (!matchPattern(edgePaddingLow, m_Constant(&edgePaddingLowAttr)) ||
+        !matchPattern(edgePaddingHigh, m_Constant(&edgePaddingHighAttr)) ||
+        !matchPattern(interiorPadding, m_Constant(&interiorPaddingAttr)))
+      return rewriter.notifyMatchFailure(op, "edge padding is not a constant");
+
+    rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
+        op, op.getType(), operand, paddingValue,
+        convertToDenseI64ArrayAttr(edgePaddingLowAttr),
+        convertToDenseI64ArrayAttr(edgePaddingHighAttr),
+        convertToDenseI64ArrayAttr(interiorPaddingAttr));
+    return success();
+  }
+
+private:
+  DenseI64ArrayAttr
+  convertToDenseI64ArrayAttr(DenseIntElementsAttr attr) const {
+    auto values = attr.getValues<APInt>();
+    llvm::SmallVector<int64_t> denseValues;
+    for (auto value : values)
+      denseValues.push_back(value.getSExtValue());
+    return DenseI64ArrayAttr::get(attr.getContext(), denseValues);
+  }
+};
+
 /////////////// End Imported from stablehlo
 
 // clang-format off
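
The two patterns above compose. For the same @main5 case as in the commit description, DUSToDynamicPad alone produces roughly the following intermediate, which DynamicPadToPad then folds into the static stablehlo.pad once the padding operands become constants. This is a hedged sketch: the pattern actually emits clamp/reshape/subtract/concatenate ops on the start indices (shown here already constant-folded), and value names and the exact printed form of stablehlo.dynamic_pad are illustrative.

// low  = start indices clamped to [0, operand_dim - update_dim], per dimension
// high = (operand_dim - update_dim) - low, per dimension
%pad_val  = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%low      = stablehlo.constant dense<[0, 0, 0]> : tensor<3xi32>
%high     = stablehlo.constant dense<[0, 2, 0]> : tensor<3xi32>
%interior = stablehlo.constant dense<0> : tensor<3xi32>
%0 = stablehlo.dynamic_pad %arg0, %pad_val, %low, %high, %interior : (tensor<2x2x3xf32>, tensor<f32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<2x4x3xf32>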
@@ -23484,7 +23680,7 @@ struct EnzymeHLOOptPass
         BinaryConstProp<stablehlo::SubtractOp, stablehlo::subtractOp>,
         BinaryConstProp<stablehlo::XorOp, stablehlo::xorOp>>(context);
 
-    patterns.add<GatherConstProp>(context);
+    patterns.add<GatherConstProp, ClampConstProp>(context);
 
     patterns.add<BinaryOpTransposeSimplify<stablehlo::AddOp>,
                  BinaryOpTransposeSimplify<stablehlo::SubtractOp>,
@@ -23747,7 +23943,9 @@ struct EnzymeHLOOptPass
         MulReduceSliceFusion,
         MinReduceSliceFusion,
         MaxReduceSliceFusion,
-        CaseToIf
+        CaseToIf,
+        DUSToDynamicPad,
+        DynamicPadToPad
         >(context);
 
     patterns.add<

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 11 additions & 0 deletions
@@ -2383,3 +2383,14 @@ def ApplyElementwiseSliceToBatchPatterns : EnzymeHLOPatternOp<
 def ApplyCaseToIfPatterns : EnzymeHLOPatternOp<"case_to_if"> {
   let patterns = ["CaseToIf"];
 }
+
+def DUSToDynamicPad : EnzymeHLOPatternOp<"dus_to_dynamic_pad"> {
+  let patterns = ["DUSToDynamicPad"];
+}
+def DynamicPadToPad : EnzymeHLOPatternOp<"dynamic_pad_to_pad"> {
+  let patterns = ["DynamicPadToPad"];
+}
+
+def ApplyClampConstPropPatterns : EnzymeHLOPatternOp<"clamp_const_prop"> {
+  let patterns = ["ClampConstProp"];
+}

src/enzyme_ad/jax/primitives.py

Lines changed: 3 additions & 0 deletions
@@ -215,6 +215,7 @@ def optimization_passes(
         "binop_const_pad_subtract<1>",
         "binop_const_pad_mul<1>",
         "binop_const_pad_div<1>",
+        "clamp_const_prop<1>",
         "binop_binop_pad_pad_add<1>",
         "binop_binop_pad_pad_mul<1>",
         "binop_pad_pad_add<1>",
@@ -351,6 +352,8 @@ def optimization_passes(
         "concatenate_subtract_to_subtract_pad",
         "concatenate_broadcast_in_dim",
        "case_to_if",
+        "dus_to_dynamic_pad",
+        "dynamic_pad_to_pad",
     ]
 
     # constant propagation patterns

test/lit_tests/dus_to_pad.mlir

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+func.func @main1(%arg0: tensor<f32>) -> tensor<5x3x4xf32> attributes {enzymexla.memory_effects = []} {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<5x3x4xf32>
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<2x1x4xf32>
+    %1 = stablehlo.dynamic_update_slice %cst, %0, %c_0, %c_0, %c : (tensor<5x3x4xf32>, tensor<2x1x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x3x4xf32>
+    return %1 : tensor<5x3x4xf32>
+}
+
+// CHECK: func.func @main1(%arg0: tensor<f32>) -> tensor<5x3x4xf32> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<2x1x4xf32>
+// CHECK-NEXT: %1 = stablehlo.pad %0, %cst, low = [1, 1, 0], high = [2, 1, 0], interior = [0, 0, 0] : (tensor<2x1x4xf32>, tensor<f32>) -> tensor<5x3x4xf32>
+// CHECK-NEXT: return %1 : tensor<5x3x4xf32>
+// CHECK-NEXT: }
+
+func.func @main2(%arg0: tensor<f32>, %arg1: tensor<i64>) -> tensor<5x3x4xf32> attributes {enzymexla.memory_effects = []} {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<5x3x4xf32>
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<0> : tensor<i32>
+    %0 = stablehlo.convert %arg1 : (tensor<i64>) -> tensor<i32>
+    %1 = stablehlo.subtract %0, %c : tensor<i32>
+    %2 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<2x1x4xf32>
+    %3 = stablehlo.dynamic_update_slice %cst, %2, %c, %1, %c_0 : (tensor<5x3x4xf32>, tensor<2x1x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x3x4xf32>
+    return %3 : tensor<5x3x4xf32>
+}
+
+// CHECK: stablehlo.dynamic_update_slice
+
+func.func @main3(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5x3x4xf32> attributes {enzymexla.memory_effects = []} {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<5x3x4xf32>
+    %1 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<2x1x4xf32>
+    %2 = stablehlo.dynamic_update_slice %0, %1, %c_0, %c_0, %c : (tensor<5x3x4xf32>, tensor<2x1x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x3x4xf32>
+    return %2 : tensor<5x3x4xf32>
+}
+
+// CHECK: func.func @main3(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5x3x4xf32> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<2x1x4xf32>
+// CHECK-NEXT: %1 = stablehlo.pad %0, %arg0, low = [1, 1, 0], high = [2, 1, 0], interior = [0, 0, 0] : (tensor<2x1x4xf32>, tensor<f32>) -> tensor<5x3x4xf32>
+// CHECK-NEXT: return %1 : tensor<5x3x4xf32>
+// CHECK-NEXT: }
+
+func.func @main4(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i64>) -> tensor<5x3x4xf32> attributes {enzymexla.memory_effects = []} {
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<0> : tensor<i32>
+    %0 = stablehlo.convert %arg2 : (tensor<i64>) -> tensor<i32>
+    %1 = stablehlo.subtract %0, %c : tensor<i32>
+    %2 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f32>) -> tensor<5x3x4xf32>
+    %3 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<2x1x4xf32>
+    %4 = stablehlo.dynamic_update_slice %2, %3, %c, %1, %c_0 : (tensor<5x3x4xf32>, tensor<2x1x4xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x3x4xf32>
+    return %4 : tensor<5x3x4xf32>
+}
+
+// CHECK: stablehlo.dynamic_update_slice
+
+func.func @main5(%arg0: tensor<2x2x3xf32>) -> tensor<2x4x3xf32> attributes {enzymexla.memory_effects = []} {
+    %c = stablehlo.constant dense<-3> : tensor<i32>
+    %c_0 = stablehlo.constant dense<-1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<1> : tensor<i32>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<2x4x3xf32>
+    %0 = stablehlo.dynamic_update_slice %cst, %arg0, %c_1, %c_0, %c : (tensor<2x4x3xf32>, tensor<2x2x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2x4x3xf32>
+    return %0 : tensor<2x4x3xf32>
+}
+
+// CHECK: func.func @main5(%arg0: tensor<2x2x3xf32>) -> tensor<2x4x3xf32> attributes {enzymexla.memory_effects = []} {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: %0 = stablehlo.pad %arg0, %cst, low = [0, 0, 0], high = [0, 2, 0], interior = [0, 0, 0] : (tensor<2x2x3xf32>, tensor<f32>) -> tensor<2x4x3xf32>
+// CHECK-NEXT: return %0 : tensor<2x4x3xf32>
+// CHECK-NEXT: }
