@@ -36,12 +36,38 @@ struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
3636 LogicalResult
3737 matchAndRewrite (enzyme::ExtractOp op, OpAdaptor adaptor,
3838 ConversionPatternRewriter &rewriter) const override {
39+
40+ auto inTy = op.getInput ().getType ();
3941 auto outTy = op.getOutput ().getType ();
40- // stablehlo always has tensor type
4142 auto outRankTy = dyn_cast<RankedTensorType>(outTy);
42- auto rank = outRankTy.getRank ();
43- return failure ();
44- // stablehlo.dynamic_slice op
43+ // stablehlo always has tensor type
44+ auto inRankTy = dyn_cast<RankedTensorType>(inTy);
45+ auto ndims = inRankTy.getRank (); // is atleast 1
46+
47+ if (ndims < 1 )
48+ return failure ();
49+
50+ // dynamic_slice followed by reshape
51+ auto i64Ty = IntegerType::get (rewriter.getContext (), 64 );
52+ auto tensor0i64Ty = RankedTensorType::get ({}, i64Ty);
53+ auto zero = rewriter.create <stablehlo::ConstantOp>(
54+ op.getLoc (), rewriter.getZeroAttr (tensor0i64Ty));
55+
56+ SmallVector<Value> dynamicSliceStartSlices (ndims, zero);
57+ dynamicSliceStartSlices[0 ] = op.getIndex (); // assume its legal for no
58+
59+ SmallVector<int64_t > localRetShape = {1 };
60+ localRetShape.append (outRankTy.getShape ().begin (),
61+ outRankTy.getShape ().end ());
62+ ;
63+ auto slicedOut = rewriter.create <stablehlo::DynamicSliceOp>(
64+ op->getLoc (), op.getInput (), dynamicSliceStartSlices, localRetShape);
65+
66+ // reshape slicedOut to our final Op
67+ rewriter.replaceOpWithNewOp <stablehlo::ReshapeOp>(op, op->getLoc (), outTy,
68+ slicedOut);
69+
70+ return success ();
4571 }
4672};
4773
@@ -67,8 +93,8 @@ struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
6793 SmallVector<int64_t > newInShape = {1 };
6894 newInShape.append (inShape.begin (), inShape.end ());
6995 auto newInTy = inRankTy.clone (newInShape);
70- Value newInput = rewriter. create <stablehlo::ReshapeOp>(
71- op->getLoc (), newInTy, in, op-> getAttrs () );
96+ Value newInput =
97+ rewriter. create <stablehlo::ReshapeOp>( op->getLoc (), newInTy, in);
7298 expandedInputs.push_back (newInput);
7399 }
74100
@@ -78,6 +104,7 @@ struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
78104 return success ();
79105 }
80106};
107+
81108struct EnzymeBatchToStableHLOPass
82109 : public enzyme::impl::EnzymeBatchToStableHLOPassBase<
83110 EnzymeBatchToStableHLOPass> {
0 commit comments