Skip to content

Commit c65bd33

Browse files
committed
namespace conflict
1 parent dcf3885 commit c65bd33

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,8 @@ class AutoDiffReduceRev
15591559
auto bc2 = BroadcastInDimOp::create(builder, op.getLoc(),
15601560
oprev.getType(), inDiffe, attr);
15611561

1562-
auto res = SelectOp::create(builder, op.getLoc(), cmp, bc2, zero);
1562+
auto res =
1563+
stablehlo::SelectOp::create(builder, op.getLoc(), cmp, bc2, zero);
15631564
gutils->addToDiffe(op.getInputs()[0], res, builder);
15641565
}
15651566
if (!gutils->isConstantValue(op.getInitValues()[0])) {
@@ -1571,7 +1572,8 @@ class AutoDiffReduceRev
15711572
auto cmp = CompareOp::create(builder, op.getLoc(), ores, oprev,
15721573
ComparisonDirection::EQ);
15731574

1574-
auto res = SelectOp::create(builder, op.getLoc(), cmp, inDiffe, zeroI);
1575+
auto res = stablehlo::SelectOp::create(builder, op.getLoc(), cmp,
1576+
inDiffe, zeroI);
15751577
gutils->addToDiffe(op.getInitValues()[0], res, builder);
15761578
}
15771579
return success();
@@ -4093,7 +4095,7 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
40934095
CustomCallOp::attachInterface<SHLOGenericBatchOpInterface<CustomCallOp>>(
40944096
*context);
40954097
IotaOp::attachInterface<SHLOIotaOpBatchInterface>(*context);
4096-
SelectOp::attachInterface<SHLOSelectOpBatchInterface>(*context);
4098+
stablehlo::SelectOp::attachInterface<SHLOSelectOpBatchInterface>(*context);
40974099
SortOp::attachInterface<SHLOSortOpBatchInterface>(*context);
40984100
GetDimensionSizeOp::attachInterface<SHLOGetDimensionSizeOpBatchInterface>(
40994101
*context);

src/enzyme_ad/jax/Passes/AffineCFG.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ bool isValidSymbolInt(Operation *defOp, bool recur, Region *scope) {
7575
return true;
7676

7777
if (recur) {
78-
if (isa<SelectOp, IndexCastOp, IndexCastUIOp, AddIOp, MulIOp, DivSIOp,
79-
DivUIOp, RemSIOp, RemUIOp, SubIOp, CmpIOp, TruncIOp, ExtUIOp,
80-
ExtSIOp>(defOp))
78+
if (isa<arith::SelectOp, IndexCastOp, IndexCastUIOp, AddIOp, MulIOp,
79+
DivSIOp, DivUIOp, RemSIOp, RemUIOp, SubIOp, CmpIOp, TruncIOp,
80+
ExtUIOp, ExtSIOp>(defOp))
8181
if (llvm::all_of(defOp->getOperands(), [&](Value v) {
8282
bool b = isValidSymbolInt(v, recur, scope);
8383
// if (!b)
@@ -1109,7 +1109,7 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
11091109
iadd.getOperand(1)));
11101110
return success();
11111111
}
1112-
if (auto iadd = op.getOperand().getDefiningOp<SelectOp>()) {
1112+
if (auto iadd = op.getOperand().getDefiningOp<arith::SelectOp>()) {
11131113
PatternRewriter b(rewriter);
11141114
setLocationAfter(b, iadd.getTrueValue());
11151115
PatternRewriter b2(rewriter);
@@ -1125,8 +1125,9 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
11251125
iadd.getTrueValue());
11261126
auto falsev = arith::IndexCastOp::create(b2, op.getLoc(),
11271127
op.getType(), iadd.getFalseValue()); cond = b3CmpIOp::create(b2, cmp.getLoc(),
1128-
cmp.getPredicate(), truev, falsev); rewriter.replaceOpWithNewOp<SelectOp>(op,
1129-
cond, truev, falsev); return success();
1128+
cmp.getPredicate(), truev, falsev);
1129+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, truev, falsev); return
1130+
success();
11301131
}
11311132
}
11321133
}
@@ -1322,7 +1323,7 @@ bool handleMinMax(Value start, SmallVectorImpl<Value> &out, bool &min,
13221323
if (isValidIndex(cur, scope)) {
13231324
out.push_back(cur);
13241325
continue;
1325-
} else if (auto selOp = cur.getDefiningOp<SelectOp>()) {
1326+
} else if (auto selOp = cur.getDefiningOp<arith::SelectOp>()) {
13261327
// UB only has min of operands
13271328
if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
13281329
if (cmp.getLhs() == selOp.getTrueValue() &&
@@ -2574,7 +2575,7 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
25742575
if (isValidIndex(cur, scope)) {
25752576
lbs.push_back(cur);
25762577
continue;
2577-
} else if (auto selOp = cur.getDefiningOp<SelectOp>()) {
2578+
} else if (auto selOp = cur.getDefiningOp<arith::SelectOp>()) {
25782579
// LB only has max of operands
25792580
if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
25802581
if (cmp.getLhs() == selOp.getTrueValue() &&
@@ -2599,7 +2600,7 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
25992600
if (isValidIndex(cur, scope)) {
26002601
ubs.push_back(cur);
26012602
continue;
2602-
} else if (auto selOp = cur.getDefiningOp<SelectOp>()) {
2603+
} else if (auto selOp = cur.getDefiningOp<arith::SelectOp>()) {
26032604
// UB only has min of operands
26042605
if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
26052606
if (cmp.getLhs() == selOp.getTrueValue() &&

0 commit comments

Comments
 (0)