@@ -39,6 +39,280 @@ using namespace mlir::tosa;
3939// Operator Canonicalizers.
4040// ===----------------------------------------------------------------------===//
4141
42+ // ===----------------------------------------------------------------------===//
43+ // Tensor Data Engine Operators.
44+ // ===----------------------------------------------------------------------===//
45+
46+ // Check that the zero point of the tensor and padding operations are aligned.
47+ bool checkMatchingPadConstAndZp (Value padConst, Value zp) {
48+ // Check that padConst is a constant value and a scalar tensor
49+ DenseElementsAttr padConstAttr;
50+ if (!matchPattern (padConst, m_Constant (&padConstAttr)) ||
51+ (padConstAttr.size () != 1 )) {
52+ return false ;
53+ }
54+
55+ // Check that floating point pad is zero
56+ if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
57+ float padConstVal = (*padConstFpAttr.begin ()).convertToFloat ();
58+ return padConstVal == 0 .0f ;
59+ }
60+
61+ // Check that the zp and padConst align for the integer (quantized) case
62+ if (auto padConstIntAttr =
63+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
64+ DenseIntElementsAttr zpAttr;
65+ // Check that zp is a constant value and a scalar tensor
66+ if (!matchPattern (zp, m_Constant (&zpAttr)) || (padConstAttr.size () != 1 )) {
67+ return false ;
68+ }
69+
70+ // Check equality
71+ int64_t zpVal = (*zpAttr.begin ()).getSExtValue ();
72+ int64_t padConstVal = (*padConstIntAttr.begin ()).getSExtValue ();
73+ return zpVal == padConstVal;
74+ }
75+
76+ // Bail-out on unsupported type
77+ return false ;
78+ }
79+
80+ namespace {
// Primary template of the pad-folding policy for pooling operations. It is
// only declared here; the usable definitions are the explicit specializations
// for the individual pooling ops.
template <class OpTy> struct PoolPadFoldAdaptor;
83+
84+ template <>
85+ struct PoolPadFoldAdaptor <tosa::AvgPool2dOp> {
86+ using OpTy = tosa::AvgPool2dOp;
87+ static bool checkKernelCompliance (OpTy op, const ArrayRef<int64_t > newPad) {
88+ const llvm::ArrayRef<int64_t > kernel = op.getKernel ();
89+ if (newPad[2 ] >= kernel[1 ] || newPad[3 ] >= kernel[1 ] ||
90+ newPad[0 ] >= kernel[0 ] || newPad[1 ] >= kernel[0 ])
91+ return false ;
92+ return true ;
93+ }
94+ static bool checkPadConstCompliance (OpTy op, Value padConst) {
95+ return checkMatchingPadConstAndZp (padConst, op.getInputZp ());
96+ }
97+ static void replaceOpWithNewPad (PatternRewriter &rewriter, OpTy op,
98+ Value padInput, ArrayRef<int64_t > newPad) {
99+ rewriter.replaceOpWithNewOp <tosa::AvgPool2dOp>(
100+ op, op.getType (), padInput, op.getInputZp (), op.getOutputZp (),
101+ op.getKernel (), op.getStride (), rewriter.getDenseI64ArrayAttr (newPad),
102+ op.getAccType ());
103+ }
104+ };
105+
106+ template <>
107+ struct PoolPadFoldAdaptor <tosa::MaxPool2dOp> {
108+ using OpTy = tosa::MaxPool2dOp;
109+ static bool checkKernelCompliance (OpTy op, const ArrayRef<int64_t > newPad) {
110+ const llvm::ArrayRef<int64_t > kernel = op.getKernel ();
111+ if (newPad[2 ] >= kernel[1 ] || newPad[3 ] >= kernel[1 ] ||
112+ newPad[0 ] >= kernel[0 ] || newPad[1 ] >= kernel[0 ])
113+ return false ;
114+ return true ;
115+ }
116+ static bool checkPadConstCompliance (OpTy, Value padConst) {
117+ // Check that padConst is a constant value and a scalar tensor
118+ DenseElementsAttr padConstAttr;
119+ if (!matchPattern (padConst, m_Constant (&padConstAttr)) ||
120+ padConstAttr.size () != 1 ) {
121+ return false ;
122+ }
123+
124+ // Pad needs to be in the minimum value to be able to merge
125+ if (auto padConstFpAttr =
126+ mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
127+ const APFloat padConstVal = *padConstFpAttr.begin ();
128+ const APFloat lowestVal =
129+ APFloat::getLargest (padConstVal.getSemantics (), true );
130+ return padConstVal == lowestVal;
131+ } else if (auto padConstIntAttr =
132+ mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
133+ const APInt padConstVal = *padConstIntAttr.begin ();
134+ const unsigned int bitWidth = padConstVal.getBitWidth ();
135+ const APInt lowestVal =
136+ padConstIntAttr.getElementType ().isUnsignedInteger ()
137+ ? APInt::getZero (bitWidth)
138+ : APInt::getSignedMinValue (bitWidth);
139+ return padConstVal == lowestVal;
140+ }
141+
142+ // Bail-out on unsupported type
143+ return false ;
144+ }
145+ static void replaceOpWithNewPad (PatternRewriter &rewriter, OpTy op,
146+ Value padInput, ArrayRef<int64_t > newPad) {
147+ rewriter.replaceOpWithNewOp <tosa::MaxPool2dOp>(
148+ op, op.getType (), padInput, op.getKernel (), op.getStride (),
149+ rewriter.getDenseI64ArrayAttr (newPad), op.getNanMode ());
150+ }
151+ };
152+
153+ template <typename OpTy>
154+ struct ConvPadFoldAdaptor {
155+ static bool checkKernelCompliance (OpTy, const ArrayRef<int64_t >) {
156+ return true ;
157+ }
158+ static bool checkPadConstCompliance (OpTy op, Value padConst) {
159+ return checkMatchingPadConstAndZp (padConst, op.getInputZp ());
160+ }
161+ static void replaceOpWithNewPad (PatternRewriter &rewriter, OpTy op,
162+ Value padInput, ArrayRef<int64_t > newPad) {
163+ rewriter.replaceOpWithNewOp <OpTy>(
164+ op, op.getResult ().getType (), padInput, op.getWeight (), op.getBias (),
165+ op.getInputZp (), op.getWeightZp (), newPad, op.getStrideAttr (),
166+ op.getDilationAttr (), op.getAccType (), op.getLocalBound ());
167+ }
168+ };
169+
170+ // Pattern attempts to fold a `tosa.pad` operator to a following tensor
171+ // operation like `tosa.conv2d` by merging the padding associated with the
172+ // pad operator directly to the implicit padding of the tensor operation.
173+ // This helps eliminate the explicit padding operator if unused.
174+ template <typename OpTy, typename AdaptorTy>
175+ struct FoldPadToTensorOp : public OpRewritePattern <OpTy> {
176+ using OpRewritePattern<OpTy>::OpRewritePattern;
177+
178+ LogicalResult matchAndRewrite (OpTy tensorOp,
179+ PatternRewriter &rewriter) const override {
180+ // Check producer is a tosa::PadOp
181+ auto padOp = tensorOp.getInput ().template getDefiningOp <tosa::PadOp>();
182+ if (!padOp)
183+ return rewriter.notifyMatchFailure (tensorOp,
184+ " Producer must be a tosa::PadOp." );
185+
186+ // Validate that tensor operation has sane padding
187+ const std::vector<int64_t > &tensorOpPad = tensorOp.getPad ().vec ();
188+ if (tensorOpPad.size () != 4 ) // pad_top, pad_bottom, pad_left, pad_right
189+ return rewriter.notifyMatchFailure (
190+ tensorOp, " Tensor operation padding shall have 4 elements." );
191+
192+ // Validate tosa::PadOp padding
193+ DenseIntElementsAttr padOpPadding;
194+ if (!matchPattern (padOp.getPadding (), m_Constant (&padOpPadding))) {
195+ return rewriter.notifyMatchFailure (
196+ tensorOp,
197+ " The `padding` input specified on the tosa::PadOp must be constant." );
198+ }
199+ // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
200+ // C_after
201+ if (padOpPadding.size () != 8 )
202+ return rewriter.notifyMatchFailure (tensorOp,
203+ " Pad padding should have 8 elements." );
204+ int64_t padNBefore = (*(padOpPadding.begin () + 0 )).getLimitedValue ();
205+ int64_t padNAfter = (*(padOpPadding.begin () + 1 )).getLimitedValue ();
206+ int64_t padHBefore = (*(padOpPadding.begin () + 2 )).getLimitedValue ();
207+ int64_t padHAfter = (*(padOpPadding.begin () + 3 )).getLimitedValue ();
208+ int64_t padWBefore = (*(padOpPadding.begin () + 4 )).getLimitedValue ();
209+ int64_t padWAfter = (*(padOpPadding.begin () + 5 )).getLimitedValue ();
210+ int64_t padCBefore = (*(padOpPadding.begin () + 6 )).getLimitedValue ();
211+ int64_t padCAfter = (*(padOpPadding.begin () + 7 )).getLimitedValue ();
212+
213+ if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0 )
214+ return rewriter.notifyMatchFailure (
215+ tensorOp, " Folding padding in N or C dimensions is not supported." );
216+
217+ // Fold padding from Pad into the tensor operation
218+ // 4 elements - pad_top, pad_bottom, pad_left, pad_right
219+ SmallVector<int64_t > foldedPad (tensorOpPad.size ());
220+ foldedPad[0 ] = padHBefore + tensorOpPad[0 ];
221+ foldedPad[1 ] = padHAfter + tensorOpPad[1 ];
222+ foldedPad[2 ] = padWBefore + tensorOpPad[2 ];
223+ foldedPad[3 ] = padWAfter + tensorOpPad[3 ];
224+
225+ // Check kernel related restrictions
226+ if (!AdaptorTy::checkKernelCompliance (tensorOp, foldedPad)) {
227+ return rewriter.notifyMatchFailure (
228+ tensorOp, " Padding size not aligned with kernel restrictions." );
229+ }
230+
231+ // Check padding constant restrictions
232+ if (!AdaptorTy::checkPadConstCompliance (tensorOp, padOp.getPadConst ())) {
233+ return rewriter.notifyMatchFailure (
234+ tensorOp,
235+ " Padding constant is not aligned with operator zero-point." );
236+ }
237+
238+ // Check that padding doesn't grow more than 8K level (8192) for now
239+ if (llvm::any_of (foldedPad, [](int64_t padVal) { return padVal > 8192 ; })) {
240+ return rewriter.notifyMatchFailure (
241+ tensorOp, " Padding size more than the 8K level limit." );
242+ }
243+
244+ // Create operator
245+ AdaptorTy::replaceOpWithNewPad (rewriter, tensorOp, padOp.getInput1 (),
246+ foldedPad);
247+
248+ return success ();
249+ }
250+ };
251+ } // namespace
252+
253+ void AvgPool2dOp::getCanonicalizationPatterns (RewritePatternSet &results,
254+ MLIRContext *context) {
255+ results.add <FoldPadToTensorOp<tosa::AvgPool2dOp,
256+ PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
257+ context);
258+ }
259+
260+ void Conv2DOp::getCanonicalizationPatterns (RewritePatternSet &results,
261+ MLIRContext *context) {
262+ results.add <
263+ FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
264+ context);
265+ }
266+
267+ void DepthwiseConv2DOp::getCanonicalizationPatterns (RewritePatternSet &results,
268+ MLIRContext *context) {
269+ results.add <FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
270+ ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
271+ context);
272+ }
273+
274+ struct MaxPool2dIsNoOp : public OpRewritePattern <tosa::MaxPool2dOp> {
275+ using OpRewritePattern::OpRewritePattern;
276+
277+ LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
278+ PatternRewriter &rewriter) const override {
279+ Value input = op.getInput ();
280+ Value output = op.getOutput ();
281+ ShapedType inputType = llvm::cast<ShapedType>(input.getType ());
282+ ShapedType outputType = llvm::cast<ShapedType>(output.getType ());
283+
284+ if (!inputType.hasStaticShape () || !outputType.hasStaticShape ()) {
285+ return failure ();
286+ }
287+
288+ // If the output and input shapes are 1x1, then this is a no op.
289+ ArrayRef<int64_t > outputShape = outputType.getShape ();
290+ if (outputShape[1 ] != 1 || outputShape[2 ] != 1 ) {
291+ return failure ();
292+ }
293+
294+ ArrayRef<int64_t > inputShape = inputType.getShape ();
295+ if (inputShape[1 ] != 1 || inputShape[2 ] != 1 ) {
296+ return failure ();
297+ }
298+
299+ rewriter.replaceOp (op, input);
300+ return success ();
301+ }
302+ };
303+
304+ void MaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &results,
305+ MLIRContext *context) {
306+ results.add <MaxPool2dIsNoOp,
307+ FoldPadToTensorOp<tosa::MaxPool2dOp,
308+ PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
309+ context);
310+ }
311+
312+ // ===----------------------------------------------------------------------===//
313+ // Data Layout / Memory Reinterpretation.
314+ // ===----------------------------------------------------------------------===//
315+
42316struct ConcatOptimization : public OpRewritePattern <tosa::ConcatOp> {
43317 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
44318
@@ -175,41 +449,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
175449 results.add <ConsolidateTransposeOptimization, TransposeIsReshape>(context);
176450}
177451
// Removed side of the diff hunk: every line below carries a `-` marker and is
// the earlier placement of the MaxPool2dIsNoOp pattern. The pattern replaces a
// max_pool2d by its input when input and output are statically shaped and
// both have extent 1 in dimensions 1 and 2.
178- struct MaxPool2dIsNoOp : public OpRewritePattern <tosa::MaxPool2dOp> {
179- using OpRewritePattern::OpRewritePattern;
180-
181- LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
182- PatternRewriter &rewriter) const override {
183- Value input = op.getInput ();
184- Value output = op.getOutput ();
185- ShapedType inputType = llvm::cast<ShapedType>(input.getType ());
186- ShapedType outputType = llvm::cast<ShapedType>(output.getType ());
187-
188- if (!inputType.hasStaticShape () || !outputType.hasStaticShape ()) {
189- return failure ();
190- }
191-
192- // If the output and input shapes are 1x1, then this is a no op.
193- ArrayRef<int64_t > outputShape = outputType.getShape ();
194- if (outputShape[1 ] != 1 || outputShape[2 ] != 1 ) {
195- return failure ();
196- }
197-
198- ArrayRef<int64_t > inputShape = inputType.getShape ();
199- if (inputShape[1 ] != 1 || inputShape[2 ] != 1 ) {
200- return failure ();
201- }
202-
203- rewriter.replaceOp (op, input);
204- return success ();
205- }
206- };
207-
// Removed side of the diff hunk (`-` markers): the earlier registration for
// max_pool2d, which added only the MaxPool2dIsNoOp pattern.
208- void MaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &results,
209- MLIRContext *context) {
210- results.add <MaxPool2dIsNoOp>(context);
211- }
212-
213452struct ClampIsNoOp : public OpRewritePattern <tosa::ClampOp> {
214453 using OpRewritePattern::OpRewritePattern;
215454
0 commit comments