77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Dialect/Linalg/IR/Linalg.h"
10+ #include " mlir/Dialect/Linalg/Transforms/Transforms.h"
1011#include " mlir/Dialect/Tensor/IR/Tensor.h"
1112#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
1213#include " mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
197198// / Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198199// / the pad op has zero low paddings, or if `pack` has no padding values.
199200struct FoldPadWithPackOp : public OpRewritePattern <PackOp> {
200- using OpRewritePattern<PackOp>::OpRewritePattern;
201+ public:
202+ FoldPadWithPackOp (MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
203+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
201204
202205 LogicalResult matchAndRewrite (PackOp packOp,
203206 PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
206209 if (!padOp || padOp.getNofold () || !padOp.hasZeroLowPad ())
207210 return failure ();
208211
212+ // User controlled folding function.
213+ if (controlFn && !controlFn (&packOp.getSourceMutable ()))
214+ return failure ();
215+
209216 Value constantPaddingValue = padOp.getConstantPaddingValue ();
210217 if (!constantPaddingValue)
211218 return failure ();
@@ -220,20 +227,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
220227 packOp.getOuterDimsPerm ());
221228 return success ();
222229 }
230+
231+ private:
232+ ControlFoldIntoPackUnpackFn controlFn;
223233};
224234
225235// / Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226236// / has extract_slice semantics.
227237struct FoldUnpackWithExtractSliceOp
228238 : public OpRewritePattern<tensor::ExtractSliceOp> {
229- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
239+ public:
240+ FoldUnpackWithExtractSliceOp (MLIRContext *context,
241+ ControlFoldIntoPackUnpackFn controlFn)
242+ : OpRewritePattern<tensor::ExtractSliceOp>(context),
243+ controlFn (std::move(controlFn)) {}
230244
231245 LogicalResult matchAndRewrite (tensor::ExtractSliceOp sliceOp,
232246 PatternRewriter &rewriter) const override {
233247 auto unpackOp = sliceOp.getSource ().getDefiningOp <UnPackOp>();
234248 if (!unpackOp)
235249 return failure ();
236250
251+ // User controlled folding function.
252+ if (controlFn && !controlFn (&sliceOp.getSourceMutable ()))
253+ return failure ();
254+
237255 if (sliceOp.getResultType ().getRank () != unpackOp.getDestType ().getRank ()) {
238256 return rewriter.notifyMatchFailure (
239257 sliceOp, " rank-reduced folding is not supported" );
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
255273 unpackOp.getMixedTiles (), unpackOp.getOuterDimsPerm ());
256274 return success ();
257275 }
276+
277+ private:
278+ ControlFoldIntoPackUnpackFn controlFn;
258279};
259280
260281// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
284305// / semantics.
285306struct FoldProducerPackWithConsumerLinalgTransposeOp
286307 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
287- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
308+
309+ public:
310+ FoldProducerPackWithConsumerLinalgTransposeOp (
311+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
312+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
313+ controlFn (std::move(controlFn)) {}
288314
289315 LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
290316 PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
293319 if (!packOp)
294320 return failure ();
295321
322+ // User controlled folding function.
323+ if (controlFn && !controlFn (&linalgOp->getOpOperand (0 )))
324+ return failure ();
325+
296326 FailureOr<SmallVector<int64_t >> maybePerm =
297327 getTransposeOpPermutation (linalgOp);
298328 if (failed (maybePerm))
@@ -331,20 +361,31 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
331361
332362 return success ();
333363 }
364+
365+ private:
366+ ControlFoldIntoPackUnpackFn controlFn;
334367};
335368
336369// / Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
337370// / semantics.
338371struct FoldConsumerPackWithProducerLinalgTransposeOp
339372 : public OpRewritePattern<PackOp> {
340- using OpRewritePattern<PackOp>::OpRewritePattern;
373+
374+ public:
375+ FoldConsumerPackWithProducerLinalgTransposeOp (
376+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
377+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
341378
342379 LogicalResult matchAndRewrite (PackOp packOp,
343380 PatternRewriter &rewriter) const override {
344381 auto linalgOp = packOp.getSource ().getDefiningOp <linalg::LinalgOp>();
345382 if (!linalgOp)
346383 return failure ();
347384
385+ // User controlled folding function.
386+ if (controlFn && !controlFn (&packOp.getSourceMutable ()))
387+ return failure ();
388+
348389 FailureOr<SmallVector<int64_t >> maybePerm =
349390 getTransposeOpPermutation (linalgOp);
350391 if (failed (maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
375416
376417 return success ();
377418 }
419+
420+ private:
421+ ControlFoldIntoPackUnpackFn controlFn;
378422};
379423
380424// / Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
381425// / transpose semantics.
382426struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383427 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
384- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
428+
429+ public:
430+ FoldProducerUnPackWithConsumerLinalgTransposeOp (
431+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
432+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
433+ controlFn (std::move(controlFn)) {}
385434
386435 LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
387436 PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
390439 if (!unPackOp)
391440 return failure ();
392441
442+ // User controlled folding function.
443+ if (controlFn && !controlFn (&linalgOp->getOpOperand (0 )))
444+ return failure ();
445+
393446 FailureOr<SmallVector<int64_t >> maybePerm =
394447 getTransposeOpPermutation (linalgOp);
395448 if (failed (maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
416469
417470 return success ();
418471 }
472+
473+ private:
474+ ControlFoldIntoPackUnpackFn controlFn;
419475};
420476
421477// / Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424480 : public OpRewritePattern<UnPackOp> {
425481 using OpRewritePattern<UnPackOp>::OpRewritePattern;
426482
483+ public:
484+ FoldConsumerUnPackWithProducerLinalgTransposeOp (
485+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
486+ : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
487+
427488 LogicalResult matchAndRewrite (UnPackOp unPackOp,
428489 PatternRewriter &rewriter) const override {
429490 auto linalgOp = unPackOp.getSource ().getDefiningOp <linalg::LinalgOp>();
430491 if (!linalgOp)
431492 return failure ();
432493
494+ // User controlled folding function.
495+ if (controlFn && !controlFn (&unPackOp.getSourceMutable ()))
496+ return failure ();
497+
433498 FailureOr<SmallVector<int64_t >> maybePerm =
434499 getTransposeOpPermutation (linalgOp);
435500 if (failed (maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
474539
475540 return success ();
476541 }
542+
543+ private:
544+ ControlFoldIntoPackUnpackFn controlFn;
477545};
478546
479547// / tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
521589
522590} // namespace
523591
524- void populateFoldIntoPackAndUnpackPatterns (RewritePatternSet &patterns) {
592+ void populateFoldIntoPackAndUnpackPatterns (
593+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
525594 patterns.insert <FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526595 FoldProducerPackWithConsumerLinalgTransposeOp,
527596 FoldConsumerPackWithProducerLinalgTransposeOp,
528597 FoldConsumerUnPackWithProducerLinalgTransposeOp,
529598 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
530- patterns.getContext ());
599+ patterns.getContext (), controlFn );
531600}
532601
533602void populateSimplifyPackAndUnpackPatterns (RewritePatternSet &patterns) {
0 commit comments