@@ -331,14 +331,14 @@ struct UnitExtentReplacementInfo {
331
331
SmallVector<int64_t > targetShape;
332
332
};
333
333
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata (
334
- MLIRContext *context, GenericOp genericOp , OpOperand *opOperand,
334
+ MLIRContext *context, IndexingMapOpInterface op , OpOperand *opOperand,
335
335
llvm::SmallDenseMap<unsigned , unsigned > &oldDimsToNewDimsMap,
336
336
ArrayRef<AffineExpr> dimReplacements) {
337
337
UnitExtentReplacementInfo info;
338
338
ReassociationIndices reassociationGroup;
339
339
SmallVector<AffineExpr> newIndexExprs;
340
- AffineMap indexingMap = genericOp .getMatchingIndexingMap (opOperand);
341
- ArrayRef <int64_t > operandShape = genericOp. getShape (opOperand);
340
+ AffineMap indexingMap = op .getMatchingIndexingMap (opOperand);
341
+ SmallVector <int64_t > operandShape = op. getStaticOperandShape (opOperand);
342
342
ArrayRef<AffineExpr> exprs = indexingMap.getResults ();
343
343
344
344
auto isUnitDim = [&](unsigned dim) {
@@ -380,9 +380,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
380
380
}
381
381
382
382
FailureOr<DropUnitDimsResult>
383
- linalg::dropUnitDims (RewriterBase &rewriter, GenericOp genericOp,
383
+ linalg::dropUnitDims (RewriterBase &rewriter, IndexingMapOpInterface op,
384
+ const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
384
385
const ControlDropUnitDims &options) {
385
- SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray ();
386
+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation ());
387
+ if (!dpsOp) {
388
+ return rewriter.notifyMatchFailure (
389
+ op, " op should implement DestinationStyleOpInterface" );
390
+ }
391
+
392
+ SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray ();
386
393
if (indexingMaps.empty ())
387
394
return failure ();
388
395
@@ -392,19 +399,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
392
399
AffineMap invertedMap =
393
400
inversePermutation (concatAffineMaps (indexingMaps, rewriter.getContext ()));
394
401
if (!invertedMap) {
395
- return rewriter.notifyMatchFailure (genericOp ,
402
+ return rewriter.notifyMatchFailure (op ,
396
403
" invalid indexing maps for operation" );
397
404
}
398
405
399
406
SmallVector<int64_t > allShapesSizes;
400
- for (OpOperand &opOperand : genericOp ->getOpOperands ())
401
- llvm::append_range (allShapesSizes, genericOp. getShape (&opOperand));
407
+ for (OpOperand &opOperand : op ->getOpOperands ())
408
+ llvm::append_range (allShapesSizes, op. getStaticOperandShape (&opOperand));
402
409
403
410
// 1a. Get the allowed list of dimensions to drop from the `options`.
404
- SmallVector<unsigned > allowedUnitDims = options.controlFn (genericOp );
411
+ SmallVector<unsigned > allowedUnitDims = options.controlFn (op );
405
412
if (allowedUnitDims.empty ()) {
406
413
return rewriter.notifyMatchFailure (
407
- genericOp , " control function returns no allowed unit dims to prune" );
414
+ op , " control function returns no allowed unit dims to prune" );
408
415
}
409
416
llvm::SmallDenseSet<unsigned > unitDimsFilter (allowedUnitDims.begin (),
410
417
allowedUnitDims.end ());
@@ -417,19 +424,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
417
424
}
418
425
}
419
426
420
- // 2. Compute the iterator types of the modified op by dropping the one-trip
427
+ // 2. Compute the new loops of the modified op by dropping the one-trip
421
428
// count loops.
422
- SmallVector<utils::IteratorType> newIteratorTypes;
423
429
llvm::SmallDenseMap<unsigned , unsigned > oldDimToNewDimMap;
424
430
SmallVector<AffineExpr> dimReplacements;
425
431
unsigned newDims = 0 ;
426
- for (auto [index, attr] :
427
- llvm::enumerate (genericOp.getIteratorTypesArray ())) {
432
+ for (auto index : llvm::seq<int64_t >(op.getStaticLoopRanges ().size ())) {
428
433
if (unitDims.count (index)) {
429
434
dimReplacements.push_back (
430
435
getAffineConstantExpr (0 , rewriter.getContext ()));
431
436
} else {
432
- newIteratorTypes.push_back (attr);
433
437
oldDimToNewDimMap[index] = newDims;
434
438
dimReplacements.push_back (
435
439
getAffineDimExpr (newDims, rewriter.getContext ()));
@@ -462,9 +466,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
462
466
}
463
467
return false ;
464
468
};
465
- for (OpOperand &opOperand : genericOp ->getOpOperands ()) {
466
- auto indexingMap = genericOp .getMatchingIndexingMap (&opOperand);
467
- ArrayRef <int64_t > shape = genericOp. getShape (&opOperand);
469
+ for (OpOperand &opOperand : op ->getOpOperands ()) {
470
+ auto indexingMap = op .getMatchingIndexingMap (&opOperand);
471
+ SmallVector <int64_t > shape = op. getStaticOperandShape (&opOperand);
468
472
if (!hasCollapsibleType (opOperand)) {
469
473
AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols (
470
474
dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size (), 0 );
@@ -474,9 +478,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
474
478
reassociations.push_back ({});
475
479
continue ;
476
480
}
477
- auto replacementInfo = dropUnitExtentFromOperandMetadata (
478
- rewriter.getContext (), genericOp , &opOperand, oldDimToNewDimMap ,
479
- dimReplacements);
481
+ auto replacementInfo =
482
+ dropUnitExtentFromOperandMetadata ( rewriter.getContext (), op , &opOperand,
483
+ oldDimToNewDimMap, dimReplacements);
480
484
reassociations.push_back (replacementInfo.reassociation );
481
485
newIndexingMaps.push_back (replacementInfo.indexMap );
482
486
targetShapes.push_back (replacementInfo.targetShape );
@@ -491,13 +495,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
491
495
concatAffineMaps (newIndexingMaps, rewriter.getContext ())))
492
496
return failure ();
493
497
494
- Location loc = genericOp .getLoc ();
498
+ Location loc = op .getLoc ();
495
499
// 4. For each of the operands, collapse the operand to convert
496
500
// from original shape to shape in the modified operation if needed,
497
501
// either through use of reshapes or rank-reducing slices as
498
502
// specified in `options`.
499
503
SmallVector<Value> newOperands;
500
- for (OpOperand &opOperand : genericOp ->getOpOperands ()) {
504
+ for (OpOperand &opOperand : op ->getOpOperands ()) {
501
505
int64_t idx = opOperand.getOperandNumber ();
502
506
if (!collapsed[idx]) {
503
507
newOperands.push_back (opOperand.get ());
@@ -508,31 +512,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
508
512
options.rankReductionStrategy ));
509
513
}
510
514
511
- // 5. Create the `linalg.generic` operation with the new operands,
512
- // indexing maps, iterator types and result types.
513
- ArrayRef<Value> newInputs =
514
- ArrayRef<Value>(newOperands).take_front (genericOp.getNumDpsInputs ());
515
- ArrayRef<Value> newOutputs =
516
- ArrayRef<Value>(newOperands).take_back (genericOp.getNumDpsInits ());
517
- SmallVector<Type> resultTypes;
518
- resultTypes.reserve (genericOp.getNumResults ());
519
- for (unsigned i : llvm::seq<unsigned >(0 , genericOp.getNumResults ()))
520
- resultTypes.push_back (newOutputs[i].getType ());
521
- GenericOp replacementOp =
522
- rewriter.create <GenericOp>(loc, resultTypes, newInputs, newOutputs,
523
- newIndexingMaps, newIteratorTypes);
524
- rewriter.inlineRegionBefore (genericOp.getRegion (), replacementOp.getRegion (),
525
- replacementOp.getRegion ().begin ());
526
- // 5a. Replace `linalg.index` operations that refer to the dropped unit
527
- // dimensions.
528
- replaceUnitDimIndexOps (replacementOp, unitDims, rewriter);
515
+ IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder (
516
+ loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
529
517
530
518
// 6. If any result type changes, insert a reshape/slice to convert from the
531
519
// original type to the new type.
532
520
SmallVector<Value> resultReplacements;
533
- for (auto [index, result] : llvm::enumerate (replacementOp. getResults ())) {
534
- unsigned opOperandIndex = index + replacementOp .getNumDpsInputs ();
535
- Value origDest = genericOp .getDpsInitOperand (index)->get ();
521
+ for (auto [index, result] : llvm::enumerate (replacementOp-> getResults ())) {
522
+ unsigned opOperandIndex = index + dpsOp .getNumDpsInputs ();
523
+ Value origDest = dpsOp .getDpsInitOperand (index)->get ();
536
524
if (!collapsed[opOperandIndex]) {
537
525
resultReplacements.push_back (result);
538
526
continue ;
@@ -546,6 +534,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
546
534
return DropUnitDimsResult{replacementOp, resultReplacements};
547
535
}
548
536
537
+ FailureOr<DropUnitDimsResult>
538
+ linalg::dropUnitDims (RewriterBase &rewriter, GenericOp genericOp,
539
+ const ControlDropUnitDims &options) {
540
+
541
+ DroppedUnitDimsBuilder build =
542
+ [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
543
+ ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
544
+ const llvm::SmallDenseSet<unsigned > &droppedDims)
545
+ -> IndexingMapOpInterface {
546
+ auto genericOp = cast<GenericOp>(op);
547
+ // Compute the iterator types of the modified op by dropping the one-trip
548
+ // count loops.
549
+ SmallVector<utils::IteratorType> newIteratorTypes;
550
+ for (auto [index, attr] :
551
+ llvm::enumerate (genericOp.getIteratorTypesArray ())) {
552
+ if (!droppedDims.count (index))
553
+ newIteratorTypes.push_back (attr);
554
+ }
555
+
556
+ // Create the `linalg.generic` operation with the new operands,
557
+ // indexing maps, iterator types and result types.
558
+ ArrayRef<Value> newInputs =
559
+ ArrayRef<Value>(newOperands).take_front (genericOp.getNumDpsInputs ());
560
+ ArrayRef<Value> newOutputs =
561
+ ArrayRef<Value>(newOperands).take_back (genericOp.getNumDpsInits ());
562
+ SmallVector<Type> resultTypes;
563
+ resultTypes.reserve (genericOp.getNumResults ());
564
+ for (unsigned i : llvm::seq<unsigned >(0 , genericOp.getNumResults ()))
565
+ resultTypes.push_back (newOutputs[i].getType ());
566
+ GenericOp replacementOp =
567
+ b.create <GenericOp>(loc, resultTypes, newInputs, newOutputs,
568
+ newIndexingMaps, newIteratorTypes);
569
+ b.cloneRegionBefore (genericOp.getRegion (), replacementOp.getRegion (),
570
+ replacementOp.getRegion ().begin ());
571
+ // 5a. Replace `linalg.index` operations that refer to the dropped unit
572
+ // dimensions.
573
+ IRRewriter rewriter (b);
574
+ replaceUnitDimIndexOps (replacementOp, droppedDims, rewriter);
575
+
576
+ return replacementOp;
577
+ };
578
+
579
+ return dropUnitDims (rewriter, genericOp, build, options);
580
+ }
581
+
549
582
namespace {
550
583
struct DropUnitDims : public OpRewritePattern <GenericOp> {
551
584
DropUnitDims (MLIRContext *context, ControlDropUnitDims options = {},
0 commit comments