@@ -112,6 +112,63 @@ static MapScatterOp foldTransposeIntoMapScatter(RewriterBase &rewriter,
   return mapScatterOp;
 }
 
+/// Fold a tensor::ExpandShapeOp or tensor::CollapseShapeOp into a consumer
+/// `mapScatterOp`, by linearizing and then delinearizing the source indices
+/// of the `mapScatterOp`s index transformation.
+template <typename ReshapeOpTy>
+static IREE::LinalgExt::MapScatterOp
+foldReshapeIntoMapScatter(RewriterBase &rewriter, ReshapeOpTy reshapeOp,
+                          IREE::LinalgExt::MapScatterOp mapScatterOp) {
+  assert(mapScatterOp.getInput() == reshapeOp->getResult(0) &&
+         "expected reshapeOp to be the producer of mapScatterOp");
+  Location loc = reshapeOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(reshapeOp);
+  SmallVector<OpFoldResult> srcDims =
+      tensor::getMixedSizes(rewriter, loc, reshapeOp.getSrc());
+  // There can be leftover tensor.dim ops consuming the result of the reshape,
+  // but they are expected to be folded into some affine.apply ops on the source
+  // sizes by later cleanup patterns.
+  SmallVector<OpFoldResult> resultDims =
+      tensor::getMixedSizes(rewriter, loc, reshapeOp.getResult());
+
+  auto indexTransformBuilder =
+      [&](ArrayRef<BlockArgument> srcIndices) -> SmallVector<Value> {
+    auto linearizeIndexOp = rewriter.create<affine::AffineLinearizeIndexOp>(
+        mapScatterOp->getLoc(), srcIndices, srcDims, /*disjoint=*/true);
+    auto delinearizeIndexOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        mapScatterOp->getLoc(), linearizeIndexOp.getResult(), resultDims,
+        /*hasOuterBound=*/true);
+    return delinearizeIndexOp->getResults();
+  };
+  rewriter.modifyOpInPlace(mapScatterOp, [&]() {
+    mapScatterOp.insertTransformationAtStart(rewriter, indexTransformBuilder,
+                                             srcDims.size());
+    mapScatterOp.getInputMutable().assign(reshapeOp->getOperand(0));
+  });
+  return mapScatterOp;
+}
+
+/// Fold a tensor::ExpandShapeOp into a consumer `mapScatterOp`, by linearizing
+/// and then delinearizing the source indices of the `mapScatterOp`s index
+/// transformation.
+static MapScatterOp
+foldExpandShapeIntoMapScatter(RewriterBase &rewriter,
+                              tensor::ExpandShapeOp expandShapeOp,
+                              MapScatterOp mapScatterOp) {
+  return foldReshapeIntoMapScatter(rewriter, expandShapeOp, mapScatterOp);
+}
+
+/// Fold a tensor::CollapseShapeOp into a consumer `mapScatterOp`, by
+/// linearizing and then delinearizing the source indices of the
+/// `mapScatterOp`s index transformation.
+static MapScatterOp
+foldCollapseShapeIntoMapScatter(RewriterBase &rewriter,
+                                tensor::CollapseShapeOp collapseShapeOp,
+                                MapScatterOp mapScatterOp) {
+  return foldReshapeIntoMapScatter(rewriter, collapseShapeOp, mapScatterOp);
+}
+
 /// Fold an `extractSliceOp` into a consumer `mapScatterOp` by applying a mask
 /// based on the bounds of the extractSliceOp. Currently, only zero offsets and
 /// unit strides are supported.
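
As a standalone illustration (not part of the patch), the row-major index arithmetic that the new foldReshapeIntoMapScatter helper encodes with affine.linearize_index and affine.delinearize_index can be sketched in plain C++. The 6x20 -> 6x4x5 expand_shape and the helper names below are assumptions made only for this example:

// Standalone sketch: linearize a source index over the source sizes, then
// delinearize the flat index over the result sizes (both row-major), which is
// what the inserted affine ops compute symbolically at compile time.
#include <array>
#include <cassert>
#include <cstdint>

// Linearize a 2-d source index over the source sizes (row-major).
static int64_t linearize(std::array<int64_t, 2> idx,
                         std::array<int64_t, 2> sizes) {
  return idx[0] * sizes[1] + idx[1];
}

// Delinearize the flat index over the 3-d result sizes (row-major).
static std::array<int64_t, 3> delinearize(int64_t flat,
                                          std::array<int64_t, 3> sizes) {
  std::array<int64_t, 3> out{};
  for (int d = 2; d >= 0; --d) {
    out[d] = flat % sizes[d];
    flat /= sizes[d];
  }
  return out;
}

int main() {
  // Source index (2, 13) of a tensor<6x20> maps to flat index 53, which the
  // delinearization over tensor<6x4x5> turns into result index (2, 2, 3).
  int64_t flat = linearize({2, 13}, {6, 20});
  std::array<int64_t, 3> result = delinearize(flat, {6, 4, 5});
  assert(flat == 53);
  assert(result[0] == 2 && result[1] == 2 && result[2] == 3);
  return 0;
}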
@@ -219,13 +276,7 @@ static void buildNestedDistributionLoops(
   });
 }
 
-/// Fold a tensor.pad op into a iree_linalg_ext.map_scatter op, and separate
-/// the writing of padding values into a separate operation on the buffer that
-/// the map_scatter op is ultimately written into. The result buffer is taken
-/// from the direct consumer of the `mapScatterOp`, which is expected to be an
-/// `iree_codegen.store_to_buffer` op. Return failure if the result buffer is
-/// not found.
-static FailureOr<MapScatterOp>
+FailureOr<MapScatterOp>
 foldPadIntoMapScatter(RewriterBase &rewriter, tensor::PadOp padOp,
                       MapScatterOp mapScatterOp,
                       PadDistributionConfigFn padDistributionConfigFn) {
@@ -316,14 +367,9 @@ foldPadIntoMapScatter(RewriterBase &rewriter, tensor::PadOp padOp,
   return mapScatterOp;
 }
 
-/// Fold the `op` into the `mapScatterOp`, if possible. The resulting
-/// map_scatter op is returned, if the `op` was folded. Otherwise, return
-/// failure. For `PadOp`s, use the `padDistributionConfigFn` to distribute
-/// the writing of padding values to the corresponding output buffer.
-static FailureOr<MapScatterOp>
-foldIntoMapScatter(RewriterBase &rewriter, Operation *op,
-                   MapScatterOp mapScatterOp,
-                   PadDistributionConfigFn padDistributionConfigFn) {
+FailureOr<MapScatterOp> foldIntoMapScatter(RewriterBase &rewriter,
+                                           Operation *op,
+                                           MapScatterOp mapScatterOp) {
   return llvm::TypeSwitch<Operation *, FailureOr<MapScatterOp>>(op)
       .Case<linalg::CopyOp>([&](linalg::CopyOp copyOp) {
         return foldIdentityLikeOpIntoMapScatter(rewriter, copyOp, mapScatterOp);
@@ -342,47 +388,9 @@ foldIntoMapScatter(RewriterBase &rewriter, Operation *op,
         return foldExtractSliceIntoMapScatter(rewriter, extractSliceOp,
                                               mapScatterOp);
       })
-      .Case<tensor::PadOp>([&](tensor::PadOp padOp) {
-        return foldPadIntoMapScatter(rewriter, padOp, mapScatterOp,
-                                     padDistributionConfigFn);
-      })
       .Default([](Operation *) { return failure(); });
 }
 
-/// Starting from the `root`, iteratively combine any relayout op producers
-/// into a single iree_linalg_ext.map_scatter op. An identity map_scatter op
-/// is inserted before the root, and then the producers of the map_scatter op
-/// are folded into the map_scatter until an unsupported op is reached.
-static void
-combineRelayoutOpChain(RewriterBase &rewriter, MapScatterOp mapScatterOp,
-                       PadDistributionConfigFn padDistributionConfigFn) {
-  Operation *relayoutOp = mapScatterOp.getInput().getDefiningOp();
-  if (!relayoutOp) {
-    return;
-  }
-  MapScatterOp combinedRelayoutOp = mapScatterOp;
-  while (relayoutOp) {
-    LDBG() << "Attempting to fold " << relayoutOp->getName()
-           << " into map_scatter op:\n"
-           << *relayoutOp;
-    FailureOr<MapScatterOp> maybeCombinedRelayoutOp = foldIntoMapScatter(
-        rewriter, relayoutOp, combinedRelayoutOp, padDistributionConfigFn);
-    if (failed(maybeCombinedRelayoutOp)) {
-      LDBG() << "Failed to fold " << relayoutOp->getName()
-             << " into map_scatter op";
-      break;
-    }
-    combinedRelayoutOp = maybeCombinedRelayoutOp.value();
-    LDBG() << "Successfully folded " << relayoutOp->getName()
-           << " into map_scatter. New map_scatter op:\n"
-           << combinedRelayoutOp;
-    relayoutOp = combinedRelayoutOp.getInput().getDefiningOp();
-  }
-  if (combinedRelayoutOp.isIdentity()) {
-    rewriter.replaceOp(combinedRelayoutOp, combinedRelayoutOp.getInput());
-  }
-}
-
 // Insert identity map_scatter op after the root and replace all uses.
 static MapScatterOp insertIdentityMapScatter(RewriterBase &rewriter,
                                              OpResult root) {
@@ -406,36 +414,50 @@ static MapScatterOp insertIdentityMapScatter(RewriterBase &rewriter,
   return mapScatterOp;
 }
 
-static bool isSupportedRelayoutOp(Operation *op) {
+bool isSupportedRelayoutOp(Operation *op) {
   return isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp,
              tensor::ExtractSliceOp, tensor::PadOp, linalg::CopyOp,
              linalg::TransposeOp>(op);
 }
 
-/// Returns the leaves of all relayout op chains in the funcOp. A relayout op
-/// chain is a sequence of relayout ops (defined by `isSupportedRelayoutOp`)
-/// for which the only users of the ops in the chain are relayout ops, except
-/// for the leaves of the chain. The leaves are simply relayout ops that have
-/// non relayout op users.
-static SmallVector<OpResult> getRelayoutLeaves(FunctionOpInterface funcOp) {
-  SmallVector<OpResult> relayoutChainRoots;
-  funcOp->walk([&relayoutChainRoots](Operation *op) {
+/// Insert identity map_scatter ops after the given operation if it is a valid
+/// leaf op of a relayout op chain. A relayout op chain is a sequence of
+/// relayout ops (defined by `isSupportedRelayoutOp`) for which the only users
+/// of the ops in the chain are relayout ops, except for the leaves of the
+/// chain. The leaves are simply relayout ops that have non relayout op users.
+/// The `controlFn` is a callback on the leaf OpResult that provides control
+/// over whether or not to insert a map_scatter op.
+struct InsertMapScatterOpPattern : public RewritePattern {
+  InsertMapScatterOpPattern(MLIRContext *context,
+                            CombineRelayoutOpsControlFn controlFn = nullptr,
+                            PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+        controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
     if (!isSupportedRelayoutOp(op)) {
-      return WalkResult::advance();
+      return failure();
     }
     // Relayout ops with only relayout op users are not leaves.
     auto isDimOrSupportedRelayoutOp = [](Operation *op) {
       return isSupportedRelayoutOp(op) || isa<tensor::DimOp>(op);
     };
     if (llvm::all_of(op->getUsers(), isDimOrSupportedRelayoutOp)) {
-      return WalkResult::advance();
+      return failure();
     }
     // All relayout ops have a single result.
-    relayoutChainRoots.push_back(op->getResult(0));
-    return WalkResult::advance();
-  });
-  return relayoutChainRoots;
-}
+    OpResult leaf = op->getResult(0);
+    if (controlFn && !controlFn(leaf)) {
+      return failure();
+    }
+    (void)insertIdentityMapScatter(rewriter, leaf);
+    return success();
+  }
+
+private:
+  CombineRelayoutOpsControlFn controlFn;
+};
 
 LogicalResult
 combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
@@ -499,24 +521,24 @@ combineLayoutTransformation(MLIRContext *ctx, FunctionOpInterface funcOp,
   IRRewriter rewriter(ctx);
   simplifyComplexRelayoutOps(rewriter, funcOp);
 
-  // Start from leaf ops, and combine producer relayout ops into a single
-  // map_scatter.
-  SmallVector<OpResult> relayoutLeaves = getRelayoutLeaves(funcOp);
-  for (OpResult leaf : relayoutLeaves) {
-    if (controlFn && !controlFn(leaf)) {
-      continue;
-    }
-    MapScatterOp mapScatterOp = insertIdentityMapScatter(rewriter, leaf);
-    combineRelayoutOpChain(rewriter, mapScatterOp, padDistributionConfigFn);
-  }
-
-  // Cleanup any tensor.dim ops that may be present after relayout
-  // combination.
-  RewritePatternSet cleanupPatterns(ctx);
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(cleanupPatterns);
-  if (failed(applyPatternsGreedily(funcOp, std::move(cleanupPatterns)))) {
+  // Combine relayout operations into new map_scatter ops.
+  RewritePatternSet relayoutCombinationPatterns(ctx);
+  relayoutCombinationPatterns.add<InsertMapScatterOpPattern>(ctx, controlFn);
+  populateCombineRelayoutOpPatterns(relayoutCombinationPatterns,
+                                    padDistributionConfigFn);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(
+      relayoutCombinationPatterns);
+  if (failed(applyPatternsGreedily(funcOp,
+                                   std::move(relayoutCombinationPatterns)))) {
     return failure();
   }
+
+  // Clean up any identity map_scatter ops after combining.
+  funcOp->walk([&](MapScatterOp mapScatterOp) {
+    if (mapScatterOp.isIdentity()) {
+      rewriter.replaceOp(mapScatterOp, mapScatterOp.getInput());
+    }
+  });
   return success();
 }
 
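
As a caller-side sketch (not part of the patch), a CombineRelayoutOpsControlFn can be supplied to limit where identity map_scatter ops are inserted. This assumes the callback is invocable as bool(OpResult), which matches how `controlFn` is used in InsertMapScatterOpPattern above; the pad-only policy is purely illustrative:

// Illustrative control callback: only start a relayout chain at leaves that
// are produced by a tensor.pad op. Real targets would apply their own policy.
CombineRelayoutOpsControlFn controlFn = [](OpResult leaf) {
  return isa<tensor::PadOp>(leaf.getOwner());
};
// Registered the same way combineLayoutTransformation does above.
RewritePatternSet patterns(ctx);
patterns.add<InsertMapScatterOpPattern>(ctx, controlFn);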