Skip to content

Commit 1406731

Browse files
committed
just add conversion for Extract assuming index is 0-dim tensor
1 parent af2bfe5 commit 1406731

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,38 @@ struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
3636
LogicalResult
3737
matchAndRewrite(enzyme::ExtractOp op, OpAdaptor adaptor,
3838
ConversionPatternRewriter &rewriter) const override {
39+
40+
auto inTy = op.getInput().getType();
3941
auto outTy = op.getOutput().getType();
40-
// stablehlo always has tensor type
4142
auto outRankTy = dyn_cast<RankedTensorType>(outTy);
42-
auto rank = outRankTy.getRank();
43-
return failure();
44-
// stablehlo.dynamic_slice op
43+
// stablehlo always has tensor type
44+
auto inRankTy = dyn_cast<RankedTensorType>(inTy);
45+
auto ndims = inRankTy.getRank(); // is atleast 1
46+
47+
if (ndims < 1)
48+
return failure();
49+
50+
// dynamic_slice followed by reshape
51+
auto i64Ty = IntegerType::get(rewriter.getContext(), 64);
52+
auto tensor0i64Ty = RankedTensorType::get({}, i64Ty);
53+
auto zero = rewriter.create<stablehlo::ConstantOp>(
54+
op.getLoc(), rewriter.getZeroAttr(tensor0i64Ty));
55+
56+
SmallVector<Value> dynamicSliceStartSlices(ndims, zero);
57+
dynamicSliceStartSlices[0] = op.getIndex(); // assume its legal for no
58+
59+
SmallVector<int64_t> localRetShape = {1};
60+
localRetShape.append(outRankTy.getShape().begin(),
61+
outRankTy.getShape().end());
62+
;
63+
auto slicedOut = rewriter.create<stablehlo::DynamicSliceOp>(
64+
op->getLoc(), op.getInput(), dynamicSliceStartSlices, localRetShape);
65+
66+
// reshape slicedOut to our final Op
67+
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op->getLoc(), outTy,
68+
slicedOut);
69+
70+
return success();
4571
}
4672
};
4773

@@ -67,8 +93,8 @@ struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
6793
SmallVector<int64_t> newInShape = {1};
6894
newInShape.append(inShape.begin(), inShape.end());
6995
auto newInTy = inRankTy.clone(newInShape);
70-
Value newInput = rewriter.create<stablehlo::ReshapeOp>(
71-
op->getLoc(), newInTy, in, op->getAttrs());
96+
Value newInput =
97+
rewriter.create<stablehlo::ReshapeOp>(op->getLoc(), newInTy, in);
7298
expandedInputs.push_back(newInput);
7399
}
74100

@@ -78,6 +104,7 @@ struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
78104
return success();
79105
}
80106
};
107+
81108
struct EnzymeBatchToStableHLOPass
82109
: public enzyme::impl::EnzymeBatchToStableHLOPassBase<
83110
EnzymeBatchToStableHLOPass> {

0 commit comments

Comments
 (0)