3030#include " mlir/IR/IRMapping.h"
3131#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
3232
33+ #include " Interfaces/AutoDiffTypeInterface.h"
3334#include " mlir/Dialect/GPU/IR/GPUDialect.h"
3435
3536#include " src/enzyme_ad/jax/Dialect/Ops.h"
@@ -2529,6 +2530,7 @@ struct AffineToStableHLORaisingPass
25292530 for (auto arg : operands0) {
25302531
25312532 Attribute attr;
2533+
25322534 if (matchPattern (arg, m_Constant (&attr))) {
25332535 affine::AffineValueMap accessMap (AffineMap::get (arg.getContext ()),
25342536 {});
@@ -2539,15 +2541,21 @@ struct AffineToStableHLORaisingPass
25392541 auto unrankedTensorType = RankedTensorType::get ({}, ET);
25402542 OpBuilder builder (arg.getContext ());
25412543 builder.setInsertionPointToEnd (newBlock);
2542- auto newConst = builder.create <stablehlo::ConstantOp>(
2543- arg.getLoc (), unrankedTensorType,
2544- SplatElementsAttr::get (
2545- unrankedTensorType,
2546- ArrayRef<Attribute>(
2547- isIndex ? IntegerAttr::get (
2548- ET, cast<IntegerAttr>(attr).getValue ())
2549- : attr)));
2550- auto newVal = newConst.getResult ();
2544+ Value newVal;
2545+ if (arg.getDefiningOp <ub::PoisonOp>()) {
2546+ newVal = cast<mlir::enzyme::AutoDiffTypeInterface>(arg.getType ())
2547+ .createNullValue (builder, arg.getLoc ());
2548+ } else {
2549+ auto newConst = builder.create <stablehlo::ConstantOp>(
2550+ arg.getLoc (), unrankedTensorType,
2551+ SplatElementsAttr::get (
2552+ unrankedTensorType,
2553+ ArrayRef<Attribute>(
2554+ isIndex ? IntegerAttr::get (
2555+ ET, cast<IntegerAttr>(attr).getValue ())
2556+ : attr)));
2557+ newVal = newConst.getResult ();
2558+ }
25512559 mapping.map (arg, newVal);
25522560 maps[newVal] = accessMap;
25532561 continue ;
0 commit comments