Skip to content

Commit 319852b

Browse files
authored
fix: guaranteed analysis bfs queue was busted (#1538)
* fix: guaranteed analysis bfs queue was busted * chore: fmt * fix: try making things deterministic * chore: bazel format * fix: address review comments * fix: revert to SmallPtrSet * chore: run fmt * fix: reuse found value
1 parent 4b25005 commit 319852b

File tree

5 files changed

+48
-14
lines changed

5 files changed

+48
-14
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ cc_library(
693693
"@llvm-project//mlir:Support",
694694
"@llvm-project//mlir:ViewLikeInterface",
695695
"@stablehlo//:base",
696+
"@stablehlo//:chlo_ops",
696697
"@stablehlo//:stablehlo_ops",
697698
],
698699
)

src/enzyme_ad/jax/Utils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/ADT/SetVector.h"
3131
#include "llvm/Support/Debug.h"
3232

33+
#include "stablehlo/dialect/ChloOps.h"
3334
#include "stablehlo/dialect/StablehloOps.h"
3435

3536
#include <set>
@@ -885,6 +886,10 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
885886
return State::GUARANTEED;
886887
}
887888

889+
if (isa<chlo::ErfInvOp>(op)) {
890+
return State::NOTGUARANTEED;
891+
}
892+
888893
// Any non-negative operation that produces a non-negative result
889894
// Here we recur on the rhs, as that is more likely to be a constant.
890895
if (isa<stablehlo::MaxOp>(op)) {

src/enzyme_ad/jax/Utils.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include "mlir/IR/Types.h"
88
#include "llvm/ADT/APFloat.h"
99
#include "llvm/ADT/DenseMap.h"
10+
#include "llvm/ADT/MapVector.h"
11+
#include "llvm/ADT/SetVector.h"
1012
#include "llvm/ADT/SmallVector.h"
1113

1214
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -354,7 +356,7 @@ template <typename Child> class GuaranteedResultAnalysisBase {
354356

355357
// Map of operations we have seen before. The target of the map[o] is a list
356358
// of sub-queries, that if all true prove that `o` is no-nan.
357-
DenseMap<Operation *, SmallPtrSet<Operation *, 2>> seen;
359+
llvm::MapVector<Operation *, llvm::SmallPtrSet<Operation *, 2>> seen;
358360

359361
// Inverse of seen. A map of operations `p` we still need to prove, to a
360362
// list of values that require `p` to be proven.
@@ -454,10 +456,12 @@ template <typename Child> class GuaranteedResultAnalysisBase {
454456
case State::PENDING: {
455457
assert(localtodo.size());
456458
assert(seen.find(cur) == seen.end());
457-
SmallPtrSet<Operation *, 2> set(localtodo.begin(), localtodo.end());
458-
for (auto v : set) {
459+
for (auto v : localtodo) {
459460
reverseSeen[v].push_back(cur);
461+
todo.push_back(v);
460462
}
463+
llvm::SmallPtrSet<Operation *, 2> set(localtodo.begin(),
464+
localtodo.end());
461465
seen[cur] = std::move(set);
462466
break;
463467
}
@@ -475,11 +479,15 @@ template <typename Child> class GuaranteedResultAnalysisBase {
475479
});
476480
}
477481

478-
assert(opCache.find(op) != opCache.end());
479-
rewriter.modifyOpInPlace(op, [&]() {
480-
op->setAttr(attrName, BoolAttr::get(op->getContext(), true));
481-
});
482-
return true;
482+
auto found = opCache.find(op);
483+
if (found != opCache.end()) {
484+
bool guaranteed = found->second;
485+
rewriter.modifyOpInPlace(op, [&]() {
486+
op->setAttr(attrName, BoolAttr::get(op->getContext(), guaranteed));
487+
});
488+
return guaranteed;
489+
}
490+
return false;
483491
}
484492

485493
bool guaranteed(stablehlo::ConstantOp constOp, PatternRewriter &rewriter) {

test/lit_tests/abspositive.mlir

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ func.func @test1() -> tensor<18500xi64> {
99
}
1010

1111
// CHECK: func.func @test1() -> tensor<18500xi64> {
12-
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<18500xi64>
13-
// CHECK-NEXT: %0 = stablehlo.iota dim = 0 : tensor<18500xi64>
12+
// CHECK-NEXT: %c = stablehlo.constant {enzymexla.guaranteed_non_negative = true} dense<1> : tensor<18500xi64>
13+
// CHECK-NEXT: %0 = stablehlo.iota dim = 0 {enzymexla.guaranteed_non_negative = true} : tensor<18500xi64>
1414
// CHECK-NEXT: %1 = stablehlo.add %0, %c {enzymexla.guaranteed_non_negative = true} : tensor<18500xi64>
1515
// CHECK-NEXT: return %1 : tensor<18500xi64>
1616
// CHECK-NEXT: }
@@ -45,7 +45,27 @@ func.func @test4(%arg0: tensor<12xf64>) -> tensor<4x3xf64> {
4545
}
4646

4747
// CHECK: func.func @test4(%arg0: tensor<12xf64>) -> tensor<4x3xf64> {
48-
// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg0 : tensor<12xf64>
48+
// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg0 {enzymexla.guaranteed_non_negative = true} : tensor<12xf64>
4949
// CHECK-NEXT: %1 = stablehlo.reshape %0 {enzymexla.guaranteed_non_negative = true} : (tensor<12xf64>) -> tensor<4x3xf64>
5050
// CHECK-NEXT: return %1 : tensor<4x3xf64>
5151
// CHECK-NEXT: }
52+
53+
func.func @test5(%arg0: tensor<12xf64>, %arg1: tensor<12xf64>, %arg2: tensor<12xf64>) -> tensor<12xf64> {
54+
// CHECK: chlo.erf_inv %arg0 {enzymexla.guaranteed_non_negative = false}
55+
%0 = chlo.erf_inv %arg0 : tensor<12xf64> -> tensor<12xf64>
56+
%1 = stablehlo.add %arg1, %arg2 : tensor<12xf64>
57+
%2 = stablehlo.add %0, %1 : tensor<12xf64>
58+
// CHECK: stablehlo.abs
59+
%3 = stablehlo.abs %2 : tensor<12xf64>
60+
return %3 : tensor<12xf64>
61+
}
62+
63+
func.func @test6(%arg0: tensor<12xf64>, %arg1: tensor<12xf64>, %arg2: tensor<12xf64>) -> tensor<12xf64> {
64+
%cst = stablehlo.constant dense<-1.000000e+00> : tensor<12xf64>
65+
%0 = stablehlo.multiply %arg0, %cst : tensor<12xf64>
66+
%1 = stablehlo.add %arg1, %arg2 : tensor<12xf64>
67+
%2 = stablehlo.add %0, %1 : tensor<12xf64>
68+
// CHECK: stablehlo.abs
69+
%3 = stablehlo.abs %2 : tensor<12xf64>
70+
return %3 : tensor<12xf64>
71+
}

test/lit_tests/raising/affine_to_stablehlo6.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ module {
111111
// CHECK-NEXT: %c_4 = stablehlo.constant dense<-92> : tensor<185xi64>
112112
// CHECK-NEXT: %0 = stablehlo.iota dim = 0 : tensor<185xi64>
113113
// CHECK-NEXT: %1 = stablehlo.add %0, %c_4 : tensor<185xi64>
114-
// CHECK-NEXT: %2 = stablehlo.convert %1 {enzymexla.guaranteed_finite = true, enzymexla.guaranteed_non_negative = false} : (tensor<185xi64>) -> tensor<185xf64>
114+
// CHECK-NEXT: %2 = stablehlo.convert %1 {enzymexla.guaranteed_finite = true, enzymexla.guaranteed_no_nan = true, enzymexla.guaranteed_non_negative = false} : (tensor<185xi64>) -> tensor<185xf64>
115115
// CHECK-NEXT: %3 = stablehlo.add %cst_3, %2 {enzymexla.guaranteed_finite = true} : tensor<185xf64>
116116
// CHECK-NEXT: %4 = stablehlo.multiply %3, %cst_0 : tensor<185xf64>
117117
// CHECK-NEXT: %5 = stablehlo.cosine %4 : tensor<185xf64>
@@ -128,7 +128,7 @@ module {
128128
// CHECK-NEXT: %16 = stablehlo.slice %arg3 [0:1] : (tensor<186xf64>) -> tensor<1xf64>
129129
// CHECK-NEXT: %17 = stablehlo.concatenate %16, %6, dim = 0 : (tensor<1xf64>, tensor<185xf64>) -> tensor<186xf64>
130130
// CHECK-NEXT: %18 = stablehlo.add %0, %c_2 : tensor<185xi64>
131-
// CHECK-NEXT: %19 = stablehlo.convert %18 {enzymexla.guaranteed_finite = true} : (tensor<185xi64>) -> tensor<185xf64>
131+
// CHECK-NEXT: %19 = stablehlo.convert %18 {enzymexla.guaranteed_finite = true, enzymexla.guaranteed_no_nan = true} : (tensor<185xi64>) -> tensor<185xf64>
132132
// CHECK-NEXT: %20 = stablehlo.multiply %19, %cst_0 : tensor<185xf64>
133133
// CHECK-NEXT: %21 = stablehlo.sine %20 : tensor<185xf64>
134134
// CHECK-NEXT: %22 = stablehlo.sine %9 : tensor<185xf64>
@@ -138,7 +138,7 @@ module {
138138
// CHECK-NEXT: %26 = stablehlo.concatenate %25, %24, dim = 0 : (tensor<1xf64>, tensor<185xf64>) -> tensor<186xf64>
139139
// CHECK-NEXT: %27 = stablehlo.sine %4 : tensor<185xf64>
140140
// CHECK-NEXT: %28 = stablehlo.add %0, %c : tensor<185xi64>
141-
// CHECK-NEXT: %29 = stablehlo.convert %28 {enzymexla.guaranteed_finite = true, enzymexla.guaranteed_non_negative = false} : (tensor<185xi64>) -> tensor<185xf64>
141+
// CHECK-NEXT: %29 = stablehlo.convert %28 {enzymexla.guaranteed_finite = true, enzymexla.guaranteed_no_nan = true, enzymexla.guaranteed_non_negative = false} : (tensor<185xi64>) -> tensor<185xf64>
142142
// CHECK-NEXT: %30 = stablehlo.add %cst_3, %29 {enzymexla.guaranteed_finite = true} : tensor<185xf64>
143143
// CHECK-NEXT: %31 = stablehlo.multiply %30, %cst_0 : tensor<185xf64>
144144
// CHECK-NEXT: %32 = stablehlo.sine %31 : tensor<185xf64>

0 commit comments

Comments
 (0)