Commit b7ae6d3

fix: remove SelectIfActive

1 parent 055e07f

File tree: 2 files changed, 78 additions and 15 deletions

src/enzyme_ad/jax/Implementations/HLODerivatives.td (46 additions, 15 deletions)
@@ -1459,15 +1459,31 @@ def getBroadcastDimensionsWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow
 
 def BroadcastDimsToReductionDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
     SmallVector<int64_t> reduceDims;
-    auto outRank = cast<RankedTensorType>(op.getType()).getRank();
-    for (int64_t i = 0; i < outRank; i++) {
-      if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
-        reduceDims.push_back(i);
+    auto outTy = cast<RankedTensorType>(op.getType());
+    auto bcastDims = op.getBroadcastDimensions();
+    auto inTy = cast<RankedTensorType>(op.getOperand().getType());
+
+    for (auto en : llvm::enumerate(outTy.getShape())) {
+      ssize_t bcastIdx = -1;
+      for (auto en2 : llvm::enumerate(bcastDims)) {
+        if (en2.value() == en.index()) {
+          bcastIdx = en2.index();
+          break;
+        }
       }
+      if (bcastIdx != -1) {
+        if (en.value() != inTy.getShape()[bcastIdx]) {
+          reduceDims.push_back(en.index());
+          assert(inTy.getShape()[bcastIdx] == 1);
+        }
+        continue;
+      }
+      reduceDims.push_back(en.index());
     }
+
     if (gutils->width > 1) {
-      for (int64_t i = 0; i < reduceDims.size(); i++) {
-        reduceDims[i] += 1;
+      for (int i = 0; i < reduceDims.size(); i++) {
+        reduceDims[i]++;
       }
     }
     getI64Attr(builder, reduceDims);
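The old loop reduced only over output dimensions absent from broadcast_dimensions, so a mapped dimension that stretches a size-1 input extent (e.g. tensor<1xf32> broadcast with dims = [0] to tensor<3xf32>, as in the new test below) was wrongly left out of the gradient reduction. The rewritten loop also reduces those stretched dimensions. A minimal standalone C++ sketch of the same computation, with the hypothetical helper name reductionDims and plain std::vector standing in for the MLIR types:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Standalone sketch of the new reduce-dimension computation, using plain
    // std::vector in place of RankedTensorType / llvm::enumerate.
    std::vector<int64_t> reductionDims(const std::vector<int64_t> &outShape,
                                       const std::vector<int64_t> &inShape,
                                       const std::vector<int64_t> &bcastDims) {
      std::vector<int64_t> reduceDims;
      for (int64_t i = 0; i < (int64_t)outShape.size(); ++i) {
        // Find which input dimension (if any) maps to output dimension i.
        int64_t bcastIdx = -1;
        for (size_t j = 0; j < bcastDims.size(); ++j) {
          if (bcastDims[j] == i) {
            bcastIdx = (int64_t)j;
            break;
          }
        }
        if (bcastIdx != -1) {
          // Mapped dimension: reduce only when it was expanded from extent 1.
          if (outShape[i] != inShape[bcastIdx]) {
            assert(inShape[bcastIdx] == 1);
            reduceDims.push_back(i);
          }
          continue;
        }
        // Output dimension not mapped from the input: introduced by the
        // broadcast, so it is always reduced in the gradient.
        reduceDims.push_back(i);
      }
      return reduceDims;
    }

For outShape = {3}, inShape = {1}, bcastDims = {0} this returns {0}, whereas the old membership test returned the empty set.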
@@ -1503,11 +1519,33 @@ def BroadcastDimensionsToInversePermutation : GlobalExpr</*needsprimal*/0, /*nee
 }]>;
 
 def InsertDeletedReduceDimsType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+    SmallVector<int64_t> reduceDims;
     auto outTy = cast<RankedTensorType>(op.getType());
+    auto bcastDims = op.getBroadcastDimensions();
+    auto inTy = cast<RankedTensorType>(op.getOperand().getType());
     auto outShape = outTy.getShape();
+
+    for (auto en : llvm::enumerate(outTy.getShape())) {
+      ssize_t bcastIdx = -1;
+      for (auto en2 : llvm::enumerate(bcastDims)) {
+        if (en2.value() == en.index()) {
+          bcastIdx = en2.index();
+          break;
+        }
+      }
+      if (bcastIdx != -1) {
+        if (en.value() != inTy.getShape()[bcastIdx]) {
+          reduceDims.push_back(en.index());
+          assert(inTy.getShape()[bcastIdx] == 1);
+        }
+        continue;
+      }
+      reduceDims.push_back(en.index());
+    }
+
     SmallVector<int64_t> reshapeShape(outTy.getRank(), -1);
     for (auto [i, sz] : llvm::enumerate(outShape)) {
-      if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
+      if (llvm::is_contained(reduceDims, i)) {
         reshapeShape[i] = 1;
       } else {
         reshapeShape[i] = sz;
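InsertDeletedReduceDimsType now recomputes the same reduceDims set and keys the reshape shape off it, instead of off membership in broadcast_dimensions, so the stretched size-1 dimensions also collapse back to extent 1. A sketch of that second step under the same std::vector stand-ins (insertDeletedDims is a hypothetical name):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Sketch of the reshape-shape computation: dimensions that were reduced
    // away come back with extent 1; all others keep their output extent.
    std::vector<int64_t> insertDeletedDims(const std::vector<int64_t> &outShape,
                                           const std::vector<int64_t> &reduceDims) {
      std::vector<int64_t> shape(outShape.size(), -1);
      for (size_t i = 0; i < outShape.size(); ++i) {
        bool reduced = std::find(reduceDims.begin(), reduceDims.end(),
                                 (int64_t)i) != reduceDims.end();
        shape[i] = reduced ? 1 : outShape[i];
      }
      return shape;
    }

For outShape = {3} and reduceDims = {0} this yields {1}: the summed gradient is reshaped back to the tensor<1xf32> shape of the broadcast input.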
@@ -1559,12 +1597,5 @@ def : HLODerivative<"BroadcastInDimOp", (Op $x),
       )
     ],
     (
-      SelectIfActive $x,
-      (
-        BroadcastInDim
-        (ResultTypeWithBatch),
-        (Shadow $x),
-        (getBroadcastDimensionsWithBatch)
-      ),
-      (HLOConstantFP<"0">)
+      BroadcastInDim (ResultTypeWithBatch), (Shadow $x), (getBroadcastDimensionsWithBatch)
     )>;
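The derivative pattern itself drops the SelectIfActive wrapper and always emits the broadcast of the shadow. The commit does not spell out the rationale; presumably an inactive $x carries an all-zero shadow, in which case broadcasting it reproduces the zero constant the removed HLOConstantFP<"0"> branch supplied. A toy C++ model of that equivalence (broadcastShadow is hypothetical, not the Enzyme-JAX API):

    #include <array>
    #include <cstdio>

    // broadcast_in_dim of a tensor<1xf32> shadow to tensor<3xf32>, dims = [0].
    static std::array<float, 3> broadcastShadow(std::array<float, 1> dx) {
      return {dx[0], dx[0], dx[0]};
    }

    int main() {
      // An inactive argument carries an all-zero shadow, so the old
      // SelectIfActive fallback and the new unconditional broadcast agree.
      std::array<float, 1> zeroShadow = {0.0f};
      std::array<float, 3> oldFallback = {0.0f, 0.0f, 0.0f};
      std::array<float, 3> newResult = broadcastShadow(zeroShadow);
      std::printf("branches agree: %d\n", (int)(newResult == oldFallback));
      return 0;
    }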
New test file (32 additions, 0 deletions)

@@ -0,0 +1,32 @@
+// RUN: enzymexlamlir-opt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --inline --enzyme-hlo-opt %s | FileCheck %s
+
+module {
+  func.func private @"Const{typeof(slicing)}(Main.slicing)_autodiff"(%arg0: tensor<1x4x1xf32>) -> (tensor<f32>, tensor<1x4x1xf32>) {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
+    %0 = stablehlo.slice %arg0 [0:1, 0:1, 0:1] : (tensor<1x4x1xf32>) -> tensor<1x1x1xf32>
+    %1 = stablehlo.reshape %0 : (tensor<1x1x1xf32>) -> tensor<1xf32>
+    %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<1xf32>) -> tensor<3xf32>
+    %3 = stablehlo.multiply %2, %cst_0 : tensor<3xf32>
+    %4 = stablehlo.multiply %3, %3 : tensor<3xf32>
+    %5 = stablehlo.reduce(%4 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor<f32>) -> tensor<f32>
+    return %5, %arg0 : tensor<f32>, tensor<1x4x1xf32>
+  }
+  func.func @main(%arg0: tensor<1x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) {
+    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+    %0:2 = enzyme.autodiff @"Const{typeof(slicing)}(Main.slicing)_autodiff"(%arg0, %cst) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>]} : (tensor<1x4x1xf32>, tensor<f32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>)
+    return %0#1, %0#0 : tensor<1x4x1xf32>, tensor<1x4x1xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<1x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:1, 0:1, 0:1] : (tensor<1x4x1xf32>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<1x1x1xf32>) -> tensor<1xf32>
+// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<1xf32>) -> tensor<3xf32>
+// CHECK-NEXT: %3 = stablehlo.add %2, %2 : tensor<3xf32>
+// CHECK-NEXT: %4 = stablehlo.reduce(%3 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor<f32>) -> tensor<f32>
+// CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor<f32>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: %6 = stablehlo.pad %5, %cst, low = [0, 0, 0], high = [0, 3, 0], interior = [0, 0, 0] : (tensor<1x1x1xf32>, tensor<f32>) -> tensor<1x4x1xf32>
+// CHECK-NEXT: return %6, %arg0 : tensor<1x4x1xf32>, tensor<1x4x1xf32>
+// CHECK-NEXT: }
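As a sanity check on the CHECK expectations (my own arithmetic, not part of the commit): the traced function computes f(x) = Σᵢ₌₀² (x[0,0,0]·1)² = 3·x₀², so the only nonzero gradient entry is ∂f/∂x₀ = 6·x₀ at position [0,0,0], which is exactly what the reduce, reshape, and zero pad in the expected output assemble. A quick finite-difference confirmation in C++:

    #include <cmath>
    #include <cstdio>

    // f collapses the test body to a scalar: three copies of x0 squared.
    static double f(double x0) { return 3.0 * x0 * x0; }

    int main() {
      double x0 = 0.7;                            // arbitrary sample point
      double analytic = 6.0 * x0;                 // what the CHECK IR computes
      double h = 1e-6;
      double numeric = (f(x0 + h) - f(x0 - h)) / (2.0 * h);
      std::printf("analytic %.6f vs numeric %.6f\n", analytic, numeric);
      return std::fabs(analytic - numeric) < 1e-4 ? 0 : 1;
    }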
