
Commit bca1e1c

feat: expand coverage of noop_reverse (#1568)
1 parent 312f09b commit bca1e1c

File tree

2 files changed, +77 -0 lines changed


src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 32 additions & 0 deletions
@@ -13090,6 +13090,12 @@ struct NoopReverse final
 
   LogicalResult matchAndRewriteImpl(stablehlo::ReverseOp op,
                                     PatternRewriter &rewriter) const {
+    SplatElementsAttr splat;
+    if (matchPattern(op.getOperand(), m_Constant(&splat))) {
+      rewriter.replaceAllUsesWith(op, op.getOperand());
+      return success();
+    }
+
     SmallVector<int64_t> newDimensions;
     auto dimensions = op.getDimensions();
     auto shape = op.getResult().getType().getShape();
@@ -13100,6 +13106,11 @@ struct NoopReverse final
       newDimensions.push_back(dim);
     }
 
+    if (auto bcast =
+            op.getOperand().getDefiningOp<stablehlo::BroadcastInDimOp>()) {
+      peelBroadcastedDimensions(bcast, newDimensions);
+    }
+
     if (newDimensions.empty()) {
       rewriter.replaceOp(op, op.getOperand());
       return success();
@@ -13112,6 +13123,27 @@ struct NoopReverse final
                                                      newDimensions);
     return success();
   }
+
+private:
+  void peelBroadcastedDimensions(stablehlo::BroadcastInDimOp op,
+                                 SmallVectorImpl<int64_t> &dims) const {
+    DenseMap<int64_t, int64_t> dimMap;
+    for (auto [i, dim] : llvm::enumerate(op.getBroadcastDimensions())) {
+      dimMap[dim] = i;
+    }
+
+    auto opShape = cast<RankedTensorType>(op.getOperand().getType()).getShape();
+
+    auto newEnd = llvm::remove_if(dims, [&](int64_t dim) {
+      auto it = dimMap.find(dim);
+      if (it != dimMap.end()) {
+        return opShape[it->second] == 1; // if 1 then trivially expanded
+      }
+      return true; // not in broadcast dims so it was expanded
+    });
+    dims.erase(newEnd, dims.end());
+    return;
+  }
 };
 
 /// Converts gather ops to slice ops in case we have a single set of constant
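
A standalone sketch of the dimension-peeling logic, not part of the commit: it replays what peelBroadcastedDimensions does, using plain C++ containers instead of MLIR types, for the broadcast/reverse pair in the second test case below. It assumes the StableHLO convention that broadcast_dimensions[i] names the result dimension fed by operand dimension i; the concrete values (operand shape 8x1, broadcast dims [0, 3], reverse dims [3, 2, 0]) are taken from that test.

// Illustrative only; mirrors peelBroadcastedDimensions without MLIR types.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 3]
  //      : (tensor<8x1xf32>) -> tensor<8x4x3x1xf32>
  // %1 = stablehlo.reverse %0, dims = [3, 2, 0]
  std::vector<int64_t> broadcastDims = {0, 3}; // operand dim i -> result dim
  std::vector<int64_t> operandShape = {8, 1};  // tensor<8x1xf32>
  std::vector<int64_t> reverseDims = {3, 2, 0};

  // Invert the broadcast mapping: result dim -> operand dim.
  std::map<int64_t, int64_t> dimMap;
  for (std::size_t i = 0; i < broadcastDims.size(); ++i)
    dimMap[broadcastDims[i]] = static_cast<int64_t>(i);

  // Drop reverse dims that the broadcast created out of thin air, or that map
  // to a size-1 operand dim; reversing either kind is a no-op.
  auto newEnd = std::remove_if(
      reverseDims.begin(), reverseDims.end(), [&](int64_t dim) {
        auto it = dimMap.find(dim);
        if (it != dimMap.end())
          return operandShape[it->second] == 1; // trivially expanded
        return true; // not in broadcast dims, so it was expanded
      });
  reverseDims.erase(newEnd, reverseDims.end());

  // Prints "0": only the genuine size-8 dimension still needs reversing,
  // matching the CHECK line "stablehlo.reverse %0, dims = [0]" below.
  for (int64_t d : reverseDims)
    std::cout << d << ' ';
  std::cout << '\n';
}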
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=noop_reverse" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
+
+module {
+  func.func @main() -> tensor<8x4x3xf32> {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<8x4x3xf32>
+    %1 = stablehlo.reverse %cst, dims = [2, 1] : tensor<8x4x3xf32>
+    return %1 : tensor<8x4x3xf32>
+    // CHECK: %cst = stablehlo.constant dense<0.000000e+00> : tensor<8x4x3xf32>
+    // CHECK-NEXT: return %cst : tensor<8x4x3xf32>
+  }
+}
+
+module {
+  func.func @main(%arg0: tensor<8x1xf32>) -> tensor<8x4x3x1xf32> {
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 3] : (tensor<8x1xf32>) -> tensor<8x4x3x1xf32>
+    %1 = stablehlo.reverse %0, dims = [3, 2, 0] : tensor<8x4x3x1xf32>
+    return %1 : tensor<8x4x3x1xf32>
+
+    // CHECK: %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 3] : (tensor<8x1xf32>) -> tensor<8x4x3x1xf32>
+    // CHECK-NEXT: %1 = stablehlo.reverse %0, dims = [0] : tensor<8x4x3x1xf32>
+    // CHECK-NEXT: return %1 : tensor<8x4x3x1xf32>
+  }
+}
+
+module {
+  func.func @main(%arg0: tensor<8x1xf32>) -> tensor<8x4x3x1xf32> {
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 3] : (tensor<8x1xf32>) -> tensor<8x4x3x1xf32>
+    %1 = stablehlo.reverse %0, dims = [2, 0] : tensor<8x4x3x1xf32>
+    return %1 : tensor<8x4x3x1xf32>
+
+    // CHECK: %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 3] : (tensor<8x1xf32>) -> tensor<8x4x3x1xf32>
+    // CHECK-NEXT: %1 = stablehlo.reverse %0, dims = [0] : tensor<8x4x3x1xf32>
+    // CHECK-NEXT: return %1 : tensor<8x4x3x1xf32>
+  }
+}
+
+module {
+  func.func @main(%arg0: tensor<1x8xf32>) -> tensor<1x8xf32> {
+    %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<1x8xf32>
+    return %0 : tensor<1x8xf32>
+
+    // CHECK: %0 = stablehlo.reverse %arg0, dims = [1] : tensor<1x8xf32>
+    // CHECK-NEXT: return %0 : tensor<1x8xf32>
+  }
+}
