Skip to content

Commit a1eb9a8

Browse files
committed
fix: check valueAttr is a ElementsAttr
1 parent 095cdf6 commit a1eb9a8

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/enzyme_ad/jax/Passes/ArithRaising.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,16 @@ struct ArithRaisingPass
293293
op->walk([=](arith::ConstantOp constOp) {
294294
if (!use_stablehlo || !isa<RankedTensorType>(constOp.getType()))
295295
return;
296-
auto CT = constOp.getType();
297-
if (isa<TensorType>(CT)) {
298-
OpBuilder builder(constOp);
299-
Value newConstOp = builder.create<stablehlo::ConstantOp>(
300-
constOp.getLoc(), constOp.getValueAttr());
301-
constOp.replaceAllUsesWith(newConstOp);
302-
constOp.erase();
303-
}
296+
297+
auto valueAttr = constOp.getValueAttr();
298+
if (!isa<ElementsAttr>(valueAttr))
299+
return;
300+
301+
OpBuilder builder(constOp);
302+
Value newConstOp =
303+
builder.create<stablehlo::ConstantOp>(constOp.getLoc(), valueAttr);
304+
constOp.replaceAllUsesWith(newConstOp);
305+
constOp.erase();
304306
});
305307
op->walk([=](arith::FPToSIOp addOp) {
306308
if (!use_stablehlo || !isa<RankedTensorType>(addOp->getResultTypes()[0]))

0 commit comments

Comments
 (0)