Skip to content

Commit 97b1dc8

Browse files
authored
Fix c++ import bugs (#1310)
1 parent 0498f6e commit 97b1dc8

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

src/enzyme_ad/jax/Passes/AffineCFG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2344,7 +2344,7 @@ struct MoveSelectToAffine : public OpRewritePattern<arith::SelectOp> {
23442344

23452345
bool changed = false;
23462346
auto condOp = ifOp.getCondition().getDefiningOp();
2347-
if (isa<AndIOp, OrIOp>(condOp)) {
2347+
if (condOp && isa<AndIOp, OrIOp>(condOp)) {
23482348
// condition, Negated
23492349

23502350
for (auto &opv : condOp->getOpOperands()) {

src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/IR/IRMapping.h"
3131
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3232

33+
#include "Interfaces/AutoDiffTypeInterface.h"
3334
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3435

3536
#include "src/enzyme_ad/jax/Dialect/Ops.h"
@@ -2529,6 +2530,7 @@ struct AffineToStableHLORaisingPass
25292530
for (auto arg : operands0) {
25302531

25312532
Attribute attr;
2533+
25322534
if (matchPattern(arg, m_Constant(&attr))) {
25332535
affine::AffineValueMap accessMap(AffineMap::get(arg.getContext()),
25342536
{});
@@ -2539,15 +2541,21 @@ struct AffineToStableHLORaisingPass
25392541
auto unrankedTensorType = RankedTensorType::get({}, ET);
25402542
OpBuilder builder(arg.getContext());
25412543
builder.setInsertionPointToEnd(newBlock);
2542-
auto newConst = builder.create<stablehlo::ConstantOp>(
2543-
arg.getLoc(), unrankedTensorType,
2544-
SplatElementsAttr::get(
2545-
unrankedTensorType,
2546-
ArrayRef<Attribute>(
2547-
isIndex ? IntegerAttr::get(
2548-
ET, cast<IntegerAttr>(attr).getValue())
2549-
: attr)));
2550-
auto newVal = newConst.getResult();
2544+
Value newVal;
2545+
if (arg.getDefiningOp<ub::PoisonOp>()) {
2546+
newVal = cast<mlir::enzyme::AutoDiffTypeInterface>(arg.getType())
2547+
.createNullValue(builder, arg.getLoc());
2548+
} else {
2549+
auto newConst = builder.create<stablehlo::ConstantOp>(
2550+
arg.getLoc(), unrankedTensorType,
2551+
SplatElementsAttr::get(
2552+
unrankedTensorType,
2553+
ArrayRef<Attribute>(
2554+
isIndex ? IntegerAttr::get(
2555+
ET, cast<IntegerAttr>(attr).getValue())
2556+
: attr)));
2557+
newVal = newConst.getResult();
2558+
}
25512559
mapping.map(arg, newVal);
25522560
maps[newVal] = accessMap;
25532561
continue;

0 commit comments

Comments
 (0)