Skip to content

Commit 3ebe5d6

Browse files
IanWood1Groverkss
andauthored
[mlir][linalg] Drop unit dims on IndexingMapOpInterface (#150280)
Generalizes `dropUnitDims` to operate on any op implementing the `IndexingMapOpInterface`. Operation specific creation is handled by passing a builder that will construct the new operation based on the dropped dimensions. --------- Signed-off-by: Ian Wood <[email protected]> Co-authored-by: Kunwar Grover <[email protected]>
1 parent 8f8b436 commit 3ebe5d6

File tree

2 files changed

+88
-45
lines changed

2 files changed

+88
-45
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,20 @@ struct ControlDropUnitDims {
537537
return SmallVector<unsigned>{};
538538
};
539539
};
540+
540541
struct DropUnitDimsResult {
541-
linalg::GenericOp resultOp;
542+
IndexingMapOpInterface resultOp;
542543
SmallVector<Value> replacements;
543544
};
545+
using DroppedUnitDimsBuilder = std::function<IndexingMapOpInterface(
546+
Location loc, OpBuilder &, IndexingMapOpInterface,
547+
ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
548+
const llvm::SmallDenseSet<unsigned> &droppedDims)>;
549+
550+
FailureOr<DropUnitDimsResult>
551+
dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
552+
const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
553+
const ControlDropUnitDims &options);
544554
FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
545555
GenericOp genericOp,
546556
const ControlDropUnitDims &options);

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 77 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,14 @@ struct UnitExtentReplacementInfo {
331331
SmallVector<int64_t> targetShape;
332332
};
333333
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
334-
MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
334+
MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
335335
llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
336336
ArrayRef<AffineExpr> dimReplacements) {
337337
UnitExtentReplacementInfo info;
338338
ReassociationIndices reassociationGroup;
339339
SmallVector<AffineExpr> newIndexExprs;
340-
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
341-
ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
340+
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
341+
SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
342342
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
343343

344344
auto isUnitDim = [&](unsigned dim) {
@@ -380,9 +380,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
380380
}
381381

382382
FailureOr<DropUnitDimsResult>
383-
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
383+
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
384+
const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
384385
const ControlDropUnitDims &options) {
385-
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
386+
auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
387+
if (!dpsOp) {
388+
return rewriter.notifyMatchFailure(
389+
op, "op should implement DestinationStyleOpInterface");
390+
}
391+
392+
SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
386393
if (indexingMaps.empty())
387394
return failure();
388395

@@ -392,19 +399,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
392399
AffineMap invertedMap =
393400
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
394401
if (!invertedMap) {
395-
return rewriter.notifyMatchFailure(genericOp,
402+
return rewriter.notifyMatchFailure(op,
396403
"invalid indexing maps for operation");
397404
}
398405

399406
SmallVector<int64_t> allShapesSizes;
400-
for (OpOperand &opOperand : genericOp->getOpOperands())
401-
llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
407+
for (OpOperand &opOperand : op->getOpOperands())
408+
llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
402409

403410
// 1a. Get the allowed list of dimensions to drop from the `options`.
404-
SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
411+
SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
405412
if (allowedUnitDims.empty()) {
406413
return rewriter.notifyMatchFailure(
407-
genericOp, "control function returns no allowed unit dims to prune");
414+
op, "control function returns no allowed unit dims to prune");
408415
}
409416
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
410417
allowedUnitDims.end());
@@ -417,19 +424,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
417424
}
418425
}
419426

420-
// 2. Compute the iterator types of the modified op by dropping the one-trip
427+
// 2. Compute the new loops of the modified op by dropping the one-trip
421428
// count loops.
422-
SmallVector<utils::IteratorType> newIteratorTypes;
423429
llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
424430
SmallVector<AffineExpr> dimReplacements;
425431
unsigned newDims = 0;
426-
for (auto [index, attr] :
427-
llvm::enumerate(genericOp.getIteratorTypesArray())) {
432+
for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
428433
if (unitDims.count(index)) {
429434
dimReplacements.push_back(
430435
getAffineConstantExpr(0, rewriter.getContext()));
431436
} else {
432-
newIteratorTypes.push_back(attr);
433437
oldDimToNewDimMap[index] = newDims;
434438
dimReplacements.push_back(
435439
getAffineDimExpr(newDims, rewriter.getContext()));
@@ -462,9 +466,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
462466
}
463467
return false;
464468
};
465-
for (OpOperand &opOperand : genericOp->getOpOperands()) {
466-
auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
467-
ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
469+
for (OpOperand &opOperand : op->getOpOperands()) {
470+
auto indexingMap = op.getMatchingIndexingMap(&opOperand);
471+
SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
468472
if (!hasCollapsibleType(opOperand)) {
469473
AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
470474
dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
@@ -474,9 +478,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
474478
reassociations.push_back({});
475479
continue;
476480
}
477-
auto replacementInfo = dropUnitExtentFromOperandMetadata(
478-
rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
479-
dimReplacements);
481+
auto replacementInfo =
482+
dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
483+
oldDimToNewDimMap, dimReplacements);
480484
reassociations.push_back(replacementInfo.reassociation);
481485
newIndexingMaps.push_back(replacementInfo.indexMap);
482486
targetShapes.push_back(replacementInfo.targetShape);
@@ -491,13 +495,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
491495
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
492496
return failure();
493497

494-
Location loc = genericOp.getLoc();
498+
Location loc = op.getLoc();
495499
// 4. For each of the operands, collapse the operand to convert
496500
// from original shape to shape in the modified operation if needed,
497501
// either through use of reshapes or rank-reducing slices as
498502
// specified in `options`.
499503
SmallVector<Value> newOperands;
500-
for (OpOperand &opOperand : genericOp->getOpOperands()) {
504+
for (OpOperand &opOperand : op->getOpOperands()) {
501505
int64_t idx = opOperand.getOperandNumber();
502506
if (!collapsed[idx]) {
503507
newOperands.push_back(opOperand.get());
@@ -508,31 +512,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
508512
options.rankReductionStrategy));
509513
}
510514

511-
// 5. Create the `linalg.generic` operation with the new operands,
512-
// indexing maps, iterator types and result types.
513-
ArrayRef<Value> newInputs =
514-
ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
515-
ArrayRef<Value> newOutputs =
516-
ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
517-
SmallVector<Type> resultTypes;
518-
resultTypes.reserve(genericOp.getNumResults());
519-
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
520-
resultTypes.push_back(newOutputs[i].getType());
521-
GenericOp replacementOp =
522-
rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
523-
newIndexingMaps, newIteratorTypes);
524-
rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
525-
replacementOp.getRegion().begin());
526-
// 5a. Replace `linalg.index` operations that refer to the dropped unit
527-
// dimensions.
528-
replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
515+
IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
516+
loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
529517

530518
// 6. If any result type changes, insert a reshape/slice to convert from the
531519
// original type to the new type.
532520
SmallVector<Value> resultReplacements;
533-
for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
534-
unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
535-
Value origDest = genericOp.getDpsInitOperand(index)->get();
521+
for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
522+
unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
523+
Value origDest = dpsOp.getDpsInitOperand(index)->get();
536524
if (!collapsed[opOperandIndex]) {
537525
resultReplacements.push_back(result);
538526
continue;
@@ -546,6 +534,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
546534
return DropUnitDimsResult{replacementOp, resultReplacements};
547535
}
548536

537+
FailureOr<DropUnitDimsResult>
538+
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
539+
const ControlDropUnitDims &options) {
540+
541+
DroppedUnitDimsBuilder build =
542+
[](Location loc, OpBuilder &b, IndexingMapOpInterface op,
543+
ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
544+
const llvm::SmallDenseSet<unsigned> &droppedDims)
545+
-> IndexingMapOpInterface {
546+
auto genericOp = cast<GenericOp>(op);
547+
// Compute the iterator types of the modified op by dropping the one-trip
548+
// count loops.
549+
SmallVector<utils::IteratorType> newIteratorTypes;
550+
for (auto [index, attr] :
551+
llvm::enumerate(genericOp.getIteratorTypesArray())) {
552+
if (!droppedDims.count(index))
553+
newIteratorTypes.push_back(attr);
554+
}
555+
556+
// Create the `linalg.generic` operation with the new operands,
557+
// indexing maps, iterator types and result types.
558+
ArrayRef<Value> newInputs =
559+
ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
560+
ArrayRef<Value> newOutputs =
561+
ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
562+
SmallVector<Type> resultTypes;
563+
resultTypes.reserve(genericOp.getNumResults());
564+
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
565+
resultTypes.push_back(newOutputs[i].getType());
566+
GenericOp replacementOp =
567+
b.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
568+
newIndexingMaps, newIteratorTypes);
569+
b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
570+
replacementOp.getRegion().begin());
571+
// 5a. Replace `linalg.index` operations that refer to the dropped unit
572+
// dimensions.
573+
IRRewriter rewriter(b);
574+
replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
575+
576+
return replacementOp;
577+
};
578+
579+
return dropUnitDims(rewriter, genericOp, build, options);
580+
}
581+
549582
namespace {
550583
struct DropUnitDims : public OpRewritePattern<GenericOp> {
551584
DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},

0 commit comments

Comments
 (0)