Commit 7c8b87a

feat: check only specific arguments of operands + cse_select (#1913)
* feat: check only specific arguments of operands
* feat: select cse
* test: partial operands check
1 parent 62ff112 commit 7c8b87a

6 files changed, +90 -89 lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 2 additions & 2 deletions
@@ -29929,8 +29929,8 @@ struct EnzymeHLOOptPass
         CSE<stablehlo::NegOp>, CSE<stablehlo::AbsOp>,
         CSE<enzymexla::RotateOp>, CSE<enzymexla::WrapOp>,
         CSE<enzymexla::ExtendOp>, CSEIota, CSE<stablehlo::CompareOp>,
-        CSE<stablehlo::GatherOp>, CSE<stablehlo::ScatterOp>>(
-        context, PatternBenefit(65000));
+        CSE<stablehlo::GatherOp>, CSE<stablehlo::ScatterOp>,
+        CSE<stablehlo::SelectOp>>(context, PatternBenefit(65000));
   }

   if (passses & 256)
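
Illustration (not part of the commit; hypothetical function @two_selects, standard StableHLO assembly syntax assumed): the newly registered CSE<stablehlo::SelectOp> pattern folds structurally identical select ops into one, the same way the existing CSE patterns do for gather, scatter, and compare.

// Before cse_select: two identical selects.
func.func @two_selects(%p: tensor<4xi1>, %a: tensor<4xf64>, %b: tensor<4xf64>) -> tensor<4xf64> {
  %0 = stablehlo.select %p, %a, %b : tensor<4xi1>, tensor<4xf64>
  %1 = stablehlo.select %p, %a, %b : tensor<4xi1>, tensor<4xf64>
  %2 = stablehlo.add %0, %1 : tensor<4xf64>
  return %2 : tensor<4xf64>
}
// After cse_select: %1 is replaced by %0, so the add becomes
//   %2 = stablehlo.add %0, %0 : tensor<4xf64>
// and the duplicate select is erased.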

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 4 additions & 0 deletions
@@ -878,6 +878,10 @@ def ApplyCSEComparePatterns : EnzymeHLOPatternOp<
     "cse_compare"> {
   let patterns = ["CSE<stablehlo::CompareOp>"];
 }
+def ApplyCSESelectPatterns : EnzymeHLOPatternOp<
+    "cse_select"> {
+  let patterns = ["CSE<stablehlo::SelectOp>"];
+}

 def CompareAbs : EnzymeHLOPatternOp<
     "compare_abs"> {

src/enzyme_ad/jax/Utils.cpp

Lines changed: 37 additions & 87 deletions
@@ -39,6 +39,7 @@
 #include <iterator>
 #include <mlir/IR/BuiltinAttributes.h>
 #include <mlir/IR/Value.h>
+#include <optional>
 #include <set>

 using namespace mlir;
@@ -765,19 +766,21 @@ NoNanResultAnalysis::State NoNanResultAnalysis::localGuaranteed(
   }

   bool recursiveCheck = false;
+  SmallVector<Value> operandsToCheck;

   if (isa<stablehlo::SliceOp, stablehlo::ConcatenateOp,
           stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
           stablehlo::TransposeOp>(op)) {
     // data movement ops
     recursiveCheck = true;
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
   } else if (isa<stablehlo::AbsOp, stablehlo::ExpOp, stablehlo::ConvertOp,
                  stablehlo::CompareOp, stablehlo::TanhOp, stablehlo::LogisticOp,
                  stablehlo::FloorOp, stablehlo::CeilOp>(op)) {
     // elementwise ops that are no-nan if all operands are not nan
     recursiveCheck = true;
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
   } else if (isa<stablehlo::AddOp, stablehlo::SubtractOp>(op)) {
-
     // If any one of the operands is a Inf, the result is Inf. If both are Inf,
     // the result is NaN.
     auto lhsFinite =
@@ -790,13 +793,15 @@ NoNanResultAnalysis::State NoNanResultAnalysis::localGuaranteed(
     }

     recursiveCheck = true;
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
   } else if (isa<stablehlo::SineOp, stablehlo::CosineOp>(op)) {

     if (!finiteResultAnalysis->guaranteed(op->getOperand(0), rewriter)) {
       return State::NOTGUARANTEED;
     }

     recursiveCheck = true;
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
   } else if (auto mulOp = dyn_cast<stablehlo::MulOp>(op)) {
     // if lhs is Inf & rhs is 0 or the other way around, mul is going to be NaN

@@ -808,38 +813,19 @@ NoNanResultAnalysis::State NoNanResultAnalysis::localGuaranteed(
     }

     recursiveCheck = true;
-  } else if (isa<mlir::stablehlo::SelectOp>(op)) {
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
+  } else if (isa<stablehlo::SelectOp>(op)) {
+    recursiveCheck = true;
+    operandsToCheck.push_back(op->getOperand(1));
+    operandsToCheck.push_back(op->getOperand(2));
+  } else if (isa<stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp>(
+                 op)) {
     recursiveCheck = true;
+    operandsToCheck.push_back(op->getOperand(0));
   }

   if (recursiveCheck) {
-    bool allOperandsGuaranteed = true;
-    for (auto operand : op->getOperands()) {
-      if (auto TT = dyn_cast<TensorType>(operand.getType())) {
-        if (TT.getElementType().isInteger())
-          continue;
-      }
-
-      {
-        auto found = valueCache.find(operand);
-        if (found != valueCache.end()) {
-          if (found->second) {
-            continue;
-          } else {
-            return State::NOTGUARANTEED;
-          }
-        }
-      }
-
-      localtodo.push_back(operand);
-      allOperandsGuaranteed = false;
-    }
-
-    if (allOperandsGuaranteed) {
-      return State::GUARANTEED;
-    } else {
-      return State::PENDING;
-    }
+    return recursivelyCheckOperands(localtodo, operandsToCheck, true);
   } else {
     return State::NOTGUARANTEED;
   }
@@ -878,54 +864,36 @@ FiniteResultAnalysis::State FiniteResultAnalysis::localGuaranteed(
   }

   bool recursiveCheck = false;
+  SmallVector<Value> operandsToCheck;

   if (isa<stablehlo::SliceOp, stablehlo::ConcatenateOp,
           stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp,
           stablehlo::TransposeOp>(op)) {
     // data movement ops
     recursiveCheck = true;
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
   } else if (isa<stablehlo::AddOp, stablehlo::SubtractOp, stablehlo::MulOp,
                  stablehlo::AbsOp, stablehlo::ExpOp, stablehlo::ConvertOp,
                  stablehlo::CompareOp>(op)) {
     // if both finite [but possibly nan], the result is finite, or nan
-
     recursiveCheck = true;
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
   } else if (isa<stablehlo::TanhOp, stablehlo::LogisticOp, stablehlo::SineOp,
                  stablehlo::CosineOp>(op)) {
     // guaranteed finite or nan result, always
     return State::GUARANTEED;
-  } else if (isa<mlir::stablehlo::SelectOp>(op)) {
+  } else if (isa<stablehlo::SelectOp>(op)) {
     recursiveCheck = true;
+    operandsToCheck.push_back(op->getOperand(1));
+    operandsToCheck.push_back(op->getOperand(2));
+  } else if (isa<stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp>(
+                 op)) {
+    recursiveCheck = true;
+    operandsToCheck.push_back(op->getOperand(0));
   }

   if (recursiveCheck) {
-    bool allOperandsGuaranteed = true;
-    for (auto operand : op->getOperands()) {
-      if (auto TT = dyn_cast<TensorType>(operand.getType())) {
-        if (TT.getElementType().isInteger())
-          continue;
-      }
-
-      {
-        auto found = valueCache.find(operand);
-        if (found != valueCache.end()) {
-          if (found->second) {
-            continue;
-          } else {
-            return State::NOTGUARANTEED;
-          }
-        }
-      }
-
-      localtodo.push_back(operand);
-      allOperandsGuaranteed = false;
-    }
-
-    if (allOperandsGuaranteed) {
-      return State::GUARANTEED;
-    } else {
-      return State::PENDING;
-    }
+    return recursivelyCheckOperands(localtodo, operandsToCheck, true);
   } else {
     return State::NOTGUARANTEED;
   }
@@ -1018,45 +986,27 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
   }

   bool recursiveCheck = false;
+  SmallVector<Value> operandsToCheck;

   if (isa<stablehlo::MinOp, stablehlo::AddOp, stablehlo::MulOp,
           stablehlo::ConcatenateOp, stablehlo::ReshapeOp,
           stablehlo::TransposeOp, stablehlo::SliceOp,
-          stablehlo::DynamicUpdateSliceOp, stablehlo::BroadcastInDimOp>(op)) {
+          stablehlo::BroadcastInDimOp>(op)) {
     // All non-negative operations that produce a non-negative result
     recursiveCheck = true;
-  } else if (isa<mlir::stablehlo::SelectOp>(op)) {
+    operandsToCheck.append(op->getOperands().begin(), op->getOperands().end());
+  } else if (isa<stablehlo::SelectOp>(op)) {
+    recursiveCheck = true;
+    operandsToCheck.push_back(op->getOperand(1));
+    operandsToCheck.push_back(op->getOperand(2));
+  } else if (isa<stablehlo::DynamicSliceOp, stablehlo::DynamicUpdateSliceOp>(
+                 op)) {
     recursiveCheck = true;
+    operandsToCheck.push_back(op->getOperand(0));
   }

   if (recursiveCheck) {
-    bool allOperandsGuaranteed = true;
-    size_t idx = 0;
-    for (auto operand : op->getOperands()) {
-      if (idx == 0 && isa<mlir::stablehlo::SelectOp>(op))
-        continue;
-      idx++;
-
-      {
-        auto found = valueCache.find(operand);
-        if (found != valueCache.end()) {
-          if (found->second) {
-            continue;
-          } else {
-            return State::NOTGUARANTEED;
-          }
-        }
-      }
-
-      localtodo.push_back(operand);
-      allOperandsGuaranteed = false;
-    }
-
-    if (allOperandsGuaranteed) {
-      return State::GUARANTEED;
-    } else {
-      return State::PENDING;
-    }
+    return recursivelyCheckOperands(localtodo, operandsToCheck, false);
   } else {
     return State::NOTGUARANTEED;
   }
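
Illustration (not part of the commit; hypothetical function @select_nonneg, standard StableHLO assembly syntax assumed): the point of collecting operandsToCheck is that for stablehlo.select only the two branch values (operands 1 and 2) feed the result, so the i1 predicate no longer needs to be proven NaN-free, finite, or non-negative. For example, NonNegativeResultAnalysis can guarantee the select below even though nothing is known about its predicate.

func.func @select_nonneg(%p: tensor<8xi1>, %t: tensor<8xf64>, %f: tensor<8xf64>) -> tensor<8xf64> {
  // Both branch values are abs results, hence non-negative.
  %t2 = stablehlo.abs %t : tensor<8xf64>
  %f2 = stablehlo.abs %f : tensor<8xf64>
  // Only %t2 and %f2 are enqueued for the recursive check; %p is skipped.
  %sel = stablehlo.select %p, %t2, %f2 : tensor<8xi1>, tensor<8xf64>
  return %sel : tensor<8xf64>
}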

src/enzyme_ad/jax/Utils.h

Lines changed: 31 additions & 0 deletions
@@ -698,6 +698,37 @@ template <typename Child> class GuaranteedResultAnalysisBase {
     return state;
   }

+protected:
+  template <typename ItTy>
+  State recursivelyCheckOperands(SmallVectorImpl<Value> &localtodo,
+                                 ItTy operands, bool skipIntegerEltypes) {
+    assert(!operands.empty() && "expected operands to not be empty");
+
+    bool allOperandsGuaranteed = true;
+    for (auto operand : operands) {
+      if (skipIntegerEltypes) {
+        if (auto TT = dyn_cast<TensorType>(operand.getType())) {
+          if (TT.getElementType().isInteger()) {
+            continue;
+          }
+        }
+      }
+
+      auto found = valueCache.find(operand);
+      if (found != valueCache.end()) {
+        if (found->second) {
+          continue;
+        }
+        return State::NOTGUARANTEED;
+      }
+
+      localtodo.push_back(operand);
+      allOperandsGuaranteed = false;
+    }
+
+    return allOperandsGuaranteed ? State::GUARANTEED : State::PENDING;
+  }
+
 private:
   State
   GuaranteedAnalysisResultToState(enzymexla::GuaranteedAnalysisResult val) {

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions
@@ -258,6 +258,7 @@ def optimization_passes(
         "cse_wrap<16>",
         "cse_rotate<16>",
         "cse_rotate<16>",
+        "cse_select<16>",
         "concat_concat_axis_swap",
         "concat_concat_to_dus",
         "broadcast_iota_simplify",

test/lit_tests/abspositive.mlir

Lines changed: 15 additions & 0 deletions
@@ -69,3 +69,18 @@ func.func @test6(%arg0: tensor<12xf64>, %arg1: tensor<12xf64>, %arg2: tensor<12x
   %3 = stablehlo.abs %2 : tensor<12xf64>
   return %3 : tensor<12xf64>
 }
+
+func.func @test7(%arg0: tensor<12xf64>, %idx: tensor<i32>) -> tensor<2x3xf64> {
+  %0 = stablehlo.multiply %arg0, %arg0 : tensor<12xf64>
+  %1 = stablehlo.reshape %0 : (tensor<12xf64>) -> tensor<4x3xf64>
+  %2 = stablehlo.dynamic_slice %1, %idx, %idx, sizes = [2, 3] : (tensor<4x3xf64>, tensor<i32>, tensor<i32>) -> tensor<2x3xf64>
+  %3 = stablehlo.abs %2 : tensor<2x3xf64>
+  return %3 : tensor<2x3xf64>
+}
+
+// CHECK: func.func @test7(%arg0: tensor<12xf64>, %arg1: tensor<i32>) -> tensor<2x3xf64> {
+// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg0 {enzymexla.non_negative = [#enzymexla<guaranteed GUARANTEED>]} : tensor<12xf64>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<12xf64>) -> tensor<4x3xf64>
+// CHECK-NEXT: %2 = stablehlo.dynamic_slice %1, %arg1, %arg1, sizes = [2, 3] {enzymexla.non_negative = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<4x3xf64>, tensor<i32>, tensor<i32>) -> tensor<2x3xf64>
+// CHECK-NEXT: return %2 : tensor<2x3xf64>
+// CHECK-NEXT: }
