2626#include " mlir/Transforms/DialectConversion.h"
2727#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2828
29+ #include " mlir/Interfaces/InferTypeOpInterface.h"
30+
2931#include < numeric>
3032#include < type_traits>
3133
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
3436
3537static mlir::Value applyPad (Location loc, Value input, ArrayRef<int64_t > pad,
3638 TypedAttr padAttr, OpBuilder &rewriter) {
37- // Input should be padded if necessary.
39+ // Input should be padded only if necessary.
3840 if (llvm::all_of (pad, [](int64_t p) { return p == 0 ; }))
3941 return input;
4042
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
4749 SmallVector<int64_t , 4 > paddedShape;
4850 SmallVector<OpFoldResult, 8 > lowIndices;
4951 SmallVector<OpFoldResult, 8 > highIndices;
50- for (int i = 0 , s = inputShape.size (); i < s; i++ ) {
52+ for (size_t i : llvm::seq ( inputShape.size ()) ) {
5153 auto lowPad = pad[i * 2 ];
5254 auto highPad = pad[i * 2 + 1 ];
5355 if (ShapedType::isDynamic (inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
131133
// Materializes the compile-time integer `attr` as an SSA Value using `builder`.
// No explicit Location is passed; `builder` is an ImplicitLocOpBuilder.
//
// Diff view: the three `-` lines are the previous implementation (an i64
// arith::ConstantOp wrapped in a createOrFold'ed arith::IndexCastOp to the
// builder's index type); the single `+` line replaces it with one
// arith::ConstantIndexOp built directly from `attr`.
132134static mlir::Value reifyConstantDim (int64_t attr,
133135 ImplicitLocOpBuilder &builder) {
134- return builder.createOrFold <arith::IndexCastOp>(
135- builder.getIndexType (),
136- builder.create <arith::ConstantOp>(builder.getI64IntegerAttr (attr)));
136+ return builder.create <arith::ConstantIndexOp>(attr);
137137}
138138
139139// Calculating the output width/height using the formula:
140140// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
141141// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
142142
143- static mlir::Value getConvOutputDim (Location loc, Value inputDim,
144- int64_t padBeforeAttr, int64_t padAfterAttr,
145- Value kernelDim, int64_t strideAttr,
146- int64_t dilationAttr, Type inputETy,
147- OpBuilder &rewriter) {
143+ static mlir::Value getConvOrPoolOutputDim (Location loc, Value inputDim,
144+ int64_t padBeforeAttr,
145+ int64_t padAfterAttr, Value kernelDim,
146+ int64_t strideAttr,
147+ int64_t dilationAttr,
148+ OpBuilder &rewriter) {
148149 ImplicitLocOpBuilder builder (loc, rewriter);
149150 auto one = rewriter.create <arith::ConstantOp>(
150151 loc, IntegerAttr::get (inputDim.getType (), 1 ));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
171172 ArrayRef<int64_t > dilationAttr, ArrayRef<int64_t > inputSizeDims,
172173 ArrayRef<int64_t > kernelSizeDims, OpBuilder &rewriter) {
173174 ShapedType inputTy = cast<ShapedType>(input.getType ());
174- Type inputETy = inputTy.getElementType ();
175175 int64_t inputRank = inputTy.getRank ();
176176
177177 SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
190190 rewriter.create <tensor::DimOp>(loc, weight, kernelDim);
191191 // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
192192 dynDims[inputDim] =
193- getConvOutputDim (loc, initDynDim, padTop, padBottom, kernelDynDim ,
194- stride, dilation, inputETy , rewriter);
193+ getConvOrPoolOutputDim (loc, initDynDim, padTop, padBottom,
194+ kernelDynDim, stride, dilation , rewriter);
195195 }
196196 }
197197
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
685685public:
686686 using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
687687
688+ // Compute the dynamic output sizes of the maxpool operation.
     //
     // Returns one Value per *dynamic* dimension of the op's result type, visited
     // in the order dim 0, 1, 2, 3; static result dimensions contribute no entry.
     // All ops are created through `rewriter` at the op's location.
     //
     // NOTE(review): dims 0..3 are indexed unconditionally, so this assumes a
     // rank-4 result -- confirm the op verifier guarantees that.
689+ static SmallVector<Value>
690+ computeDynamicOutputSizes (tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
691+ TensorType resultTy = op.getType ();
692+ Location loc = op.getLoc ();
693+
694+ TypedValue<TensorType> input = op.getInput ();
695+ ArrayRef<int64_t > kernel = op.getKernel ();
696+ ArrayRef<int64_t > pad = op.getPad ();
697+ ArrayRef<int64_t > stride = op.getStride ();
698+
699+ SmallVector<Value> dynamicDims;
700+
701+ // Batch dimension
     // Taken directly from the input's dim 0 (no pooling arithmetic applied).
702+ if (resultTy.isDynamicDim (0 ))
703+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 0 ));
704+
705+ // Height/width dimensions
706+ for (int64_t dim : {1 , 2 }) {
     // Static output extents need no runtime size value.
707+ if (!resultTy.isDynamicDim (dim))
708+ continue ;
709+
710+ // Index into the attribute arrays
     // kernel/stride are indexed by `index`; pad holds two entries per spatial
     // dim at [index * 2] and [index * 2 + 1].
     // NOTE(review): presumably pad is laid out as (before, after) pairs per
     // spatial dim -- confirm against the op definition.
711+ int64_t index = dim - 1 ;
712+
713+ // Input height/width
714+ Value ihw = rewriter.create <tensor::DimOp>(loc, input, dim);
715+
716+ // Kernel height/width
     // The kernel extent is a static attribute, so it is emitted as an index
     // constant rather than queried from a tensor.
717+ Value khw = rewriter.create <arith::ConstantIndexOp>(loc, kernel[index]);
718+
719+ // Output height/width
     // Delegates to the shared conv/pool helper, passing a dilation of 1.
720+ Value ohw = getConvOrPoolOutputDim (loc, ihw, pad[index * 2 ],
721+ pad[index * 2 + 1 ], khw, stride[index],
722+ /* dilationAttr=*/ 1 , rewriter);
723+ dynamicDims.push_back (ohw);
724+ }
725+
726+ // Channel dimension
     // Taken directly from the input's dim 3.
727+ if (resultTy.isDynamicDim (3 ))
728+ dynamicDims.push_back (rewriter.create <tensor::DimOp>(loc, input, 3 ));
729+
730+ return dynamicDims;
731+ }
732+
688733 LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
689734 PatternRewriter &rewriter) const final {
690735 Location loc = op.getLoc ();
691- Value input = op.getInput ();
692- ShapedType inputTy = cast<ShapedType>( input.getType () );
736+ TypedValue<TensorType> input = op.getInput ();
737+ ShapedType inputTy = input.getType ();
693738
694- ShapedType resultTy = cast<ShapedType>( op.getType () );
739+ ShapedType resultTy = op.getType ();
695740 Type resultETy = inputTy.getElementType ();
696741
697- auto dynamicDimsOr =
698- checkHasDynamicBatchDims (rewriter, op, {input, op.getOutput ()});
699- if (!dynamicDimsOr.has_value ())
700- return failure ();
701- SmallVector<Value> dynamicDims = *dynamicDimsOr;
742+ SmallVector<Value> dynamicDims = computeDynamicOutputSizes (op, rewriter);
702743
703744 // Determine what the initial value needs to be for the max pool op.
704745 TypedAttr initialAttr;
@@ -721,6 +762,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
721762 pad.resize (2 , 0 );
722763 llvm::append_range (pad, op.getPad ());
723764 pad.resize (pad.size () + 2 , 0 );
765+
724766 Value paddedInput = applyPad (loc, input, pad, initialAttr, rewriter);
725767
726768 Value initialValue = rewriter.create <arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
736778 loc, resultTy.getShape (), resultTy.getElementType (), dynamicDims);
737779
738780 Value filledEmptyTensor =
739- rewriter
740- .create <linalg::FillOp>(loc, ValueRange{initialValue},
741- ValueRange{emptyTensor})
781+ rewriter.create <linalg::FillOp>(loc, initialValue, emptyTensor)
742782 .result ();
743783
744784 Value fakeWindowDims =
0 commit comments