Commit cdf24b9

[Dispatch] Two fixes for CollapseDimensionsPass (iree-org#19598)
iree-org#19113 uncovered some problems with the logic in this pass. Fixes two problems:

1. If a consumer cannot be collapsed, producers can only collapse dimensions not touched by the consumer.
2. When updating which consumer loops can be collapsed, the reassociation of the producer must be taken into account, since it's possible they are not all contiguous (e.g. a transpose on an input). This is the same logic as in `updateFromConsumer`.

---------

Signed-off-by: Ian Wood <[email protected]>
1 parent 763406f commit cdf24b9
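To make the second fix concrete: a group of loops may only be collapsed together if the neighboring op that shares them also collapses them into one loop, and after mapping through an operand's indexing map (e.g. across a transpose) the neighbor's groups need not be contiguous. Below is a minimal standalone sketch of that group-intersection idea in plain C++; `Group` and `intersectGroups` are hypothetical names, and the toy model deliberately uses none of the MLIR/IREE types that appear in the diff.

#include <cstdio>
#include <map>
#include <vector>

// Toy model: a "reassociation" is a list of groups of loop indices that an op
// wants to collapse together, already expressed in one common loop numbering.
using Group = std::vector<long>;

// Split each candidate group of `thisGroups` wherever two consecutive loops
// fall into different groups of `otherGroups`; singleton groups are dropped
// because collapsing needs at least two loops.
std::vector<Group> intersectGroups(const std::vector<Group> &thisGroups,
                                   const std::vector<Group> &otherGroups) {
  // Map each loop index to the id of the `other` group containing it.
  std::map<long, long> otherGroupOf;
  for (long id = 0; id < static_cast<long>(otherGroups.size()); ++id)
    for (long loop : otherGroups[id])
      otherGroupOf[loop] = id;

  std::vector<Group> result;
  for (const Group &group : thisGroups) {
    Group current;
    long currentId = -1;
    for (long loop : group) {
      auto it = otherGroupOf.find(loop);
      long id = (it == otherGroupOf.end()) ? -1 : it->second;
      if (!current.empty() && id != currentId) {
        // The next loop collapses into a different `other` group: split here.
        if (current.size() > 1)
          result.push_back(current);
        current.clear();
      }
      currentId = id;
      current.push_back(loop);
    }
    if (current.size() > 1)
      result.push_back(current);
  }
  return result;
}

int main() {
  // The consumer would like to collapse all four loops, but the producer
  // (say, because one of its operands is transposed) can only collapse
  // [0, 1] and [2, 3] separately, so the consumer's group is split the
  // same way.
  std::vector<Group> consumer = {{0, 1, 2, 3}};
  std::vector<Group> producer = {{0, 1}, {2, 3}};
  for (const Group &group : intersectGroups(consumer, producer)) {
    std::printf("{");
    for (long loop : group)
      std::printf(" %ld", loop);
    std::printf(" }\n");
  }
  // Prints "{ 0 1 }" and "{ 2 3 }".
}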

File tree

2 files changed (+182, -73 lines changed)


compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp

Lines changed: 84 additions & 72 deletions
@@ -314,19 +314,20 @@ class CollapseInfo {
   void dump() const;
 
   // Update CollapseInfo to ensure that all dimensions collapsable in `this` are
-  // also collapsable in `consumerInfo`. This means:
-  // 1. Any dimension not collapsable in `consumerInfo` should not be
+  // also collapsable in `otherInfo`. This means:
+  // 1. Any dimension not collapsable in `otherInfo` should not be
   // collapsable in `this`
   // 2. For any pair of dimensions in `this`, if they are collapsable in
-  // `consumerInfo`, they must be collapsable into the same dimension in
-  // `consumerInfo` to be collapsable into the same dimension in `this`.
+  // `otherInfo`, they must be collapsable into the same dimension in
+  // `otherInfo` to be collapsable into the same dimension in `this`.
   // Returns true if the operation modified the number of collapsable loops.
-  bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo);
+  bool updateFromOther(FailureOr<AffineMap> otherToThisMap,
+                       const CollapseInfo &otherInfo);
 
-  // Update `collapsableLoops` by subtracting `uncollapsable` and update the
-  // reassociation indicies accordingly.
-  // Returns true if the operation modified the number of collapsable loops.
-  bool updateCollapseViaSubtract(const CollapsableLoopsSet &uncollapsable);
+  // Update `this` (which is the info for `op`) when either a producer or
+  // consumer is not collapsible. This is done by considering all the dims
+  // accessed by other to be uncollapsible.
+  bool updateFromUncollapsible(Operation *op, OpOperand *operand);
 
   // Get `collapsableLoops` after applying the transformation provided by `map`.
   // Note: doesn't modify `collapsableLoops`, the tranformation is applied to a
@@ -460,48 +461,56 @@ CollapseInfo::getTransformedReassociation(AffineMap map) const {
   return transformedReassociation;
 }
 
-bool CollapseInfo::updateFromConsumer(OpOperand *operand,
-                                      const CollapseInfo &consumerInfo) {
-  FailureOr<AffineMap> consumerToProducerMap =
-      getConsumerLoopToProducerLoopsMap(*operand);
-  if (failed(consumerToProducerMap)) {
+bool CollapseInfo::updateFromOther(FailureOr<AffineMap> otherToThisMap,
+                                   const CollapseInfo &otherInfo) {
+  if (failed(otherToThisMap)) {
     return this->clear();
   }
 
-  CollapsableLoopsSet consumerCollapsable =
-      consumerInfo.getTransformedCollapsableLoops(
-          consumerToProducerMap.value());
+  CollapsableLoopsSet otherCollapsible =
+      otherInfo.getTransformedCollapsableLoops(otherToThisMap.value());
 
-  SmallVector<ReassociationIndices> consumerReassoc =
-      consumerInfo.getTransformedReassociation(consumerToProducerMap.value());
+  SmallVector<ReassociationIndices> otherReassoc =
+      otherInfo.getTransformedReassociation(otherToThisMap.value());
 
   // Get a map from original index to the index it gets collapsed into
-  llvm::DenseMap<long, long> consumerCollapseMap;
-  for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) {
+  llvm::DenseMap<long, long> otherCollapseMap;
+  for (const auto &[idx, indicies] : llvm::enumerate(otherReassoc)) {
     for (const auto elem : indicies) {
-      consumerCollapseMap[elem] = idx;
+      otherCollapseMap[elem] = idx;
     }
   }
 
-  // Remove all collapsable loops in `producer` that are not collapsable in
-  // `consumer` (set intersect)
-  bool didChange = collapsableLoops.remove_if(
-      [&](long elem) -> bool { return !consumerCollapsable.contains(elem); });
+  // Remove all collapsable loops in `this` that both exist and are not
+  // collapsable in `other` (set intersect)
+  bool didChange = collapsableLoops.remove_if([&](long elem) -> bool {
+    // Exists and is collapsable
+    if (otherCollapsible.contains(elem)) {
+      return false;
+    }
+
+    // Does not exist in `other`.
+    if (!otherToThisMap->isFunctionOfDim(elem)) {
+      return false;
+    }
+
+    return true;
+  });
 
   // Now update the reassociation indicies given the updated `collapsableLoops`
-  // and `consumerCollapsableMap`.
+  // and `otherCollapsableMap`.
   // The idea is to reconstruct the reassociation indicies, and at each index:
   // (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add
   // `index` to either.
   //
-  // (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps
+  // (2) If `index` IS in `collapsableLoops` but `otherCollapseMap` maps
   // `index` to a different collapsed loop then the other indicies, split
   // `indicies` and insert `index` into the new one.
   //
   // For example:
-  // producer reassociation = [[0, 1], [2, 3]]
-  // consumer reassociation = [0, 1, 2, 3]
-  // then, consumer reassociation gets updated to [[0, 1], [2, 3]] because
+  // `this` reassociation = [[0, 1], [2, 3]]
+  // `other` reassociation = [0, 1, 2, 3]
+  // then, `other` reassociation gets updated to [[0, 1], [2, 3]] because
   // [0, 1] and [2, 3] get collapsed into different loops
   //
   // (3) Otherwise, keep the index
@@ -525,22 +534,25 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand,
         }
         newIndicies.clear();
         collapseIntoIdx = kUninitialized;
+      } else if (!otherCollapseMap.contains(index)) {
+        // (2) `index` does not exist in `other`.
+        newIndicies.push_back(index);
       } else if (collapseIntoIdx == kUninitialized) {
-        // (2) First occurance of collapsable loop, set collapseIntoIdx.
-        collapseIntoIdx = consumerCollapseMap.at(index);
+        // (3) First occurance of collapsable loop, set collapseIntoIdx.
+        collapseIntoIdx = otherCollapseMap.at(index);
         newIndicies.push_back(index);
-      } else if (consumerCollapseMap.at(index) != collapseIntoIdx) {
-        // (3) `index` is collapsable but not collapsable into the other loops.
+      } else if (otherCollapseMap.at(index) != collapseIntoIdx) {
+        // (4) `index` is collapsable but not collapsable into the other loops.
         // So, split them and look for other loops to collapse `index` into.
        didChange = true;
         if (newIndicies.size() > 1) {
           newReassociation.push_back(std::move(newIndicies));
         }
         newIndicies.clear();
-        collapseIntoIdx = consumerCollapseMap[index];
+        collapseIntoIdx = otherCollapseMap[index];
         newIndicies.push_back(index);
       } else {
-        // (4) `index` is collapsable and can be collapsed into
+        // (5) `index` is collapsable and can be collapsed into
         // `collapseIntoIndex`.
         newIndicies.push_back(index);
       }
@@ -554,10 +566,17 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand,
   return didChange;
 }
 
-// Update `collapsableLoops` by subtracting `uncollapsable` and update the
-// reassociation indicies accordingly.
-bool CollapseInfo::updateCollapseViaSubtract(
-    const CollapsableLoopsSet &uncollapsable) {
+bool CollapseInfo::updateFromUncollapsible(Operation *op, OpOperand *operand) {
+  auto fusionOp = cast<LinalgFusionOpInterface>(op);
+  AffineMap map = operand->getOwner() == op
+                      ? fusionOp.getMatchingIndexingMap(operand)
+                      : fusionOp.getIndexingMapMatchingResult(
+                            cast<OpResult>(operand->get()));
+
+  CollapseInfo::CollapsableLoopsSet uncollapsable;
+  for (auto expr : map.getResults()) {
+    uncollapsable.insert(cast<AffineDimExpr>(expr).getPosition());
+  }
   auto initialSize = collapsableLoops.size();
   collapsableLoops.set_subtract(uncollapsable);
   updateReassociation();
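The `updateFromUncollapsible` hunk above is what backs the first fix: when a neighboring op cannot take part in collapsing at all, every loop touched by the shared operand (per its indexing map) is treated as uncollapsible, and the remaining candidate groups are re-split around those loops. Here is a minimal standalone sketch of that subtraction in plain C++; `subtractUncollapsible` and the example sets are hypothetical and independent of the IREE/MLIR types in the diff.

#include <cstdio>
#include <set>
#include <vector>

using Group = std::vector<long>;

// Remove every loop in `touched` from the candidate groups and split the
// groups around the removed loops; groups that shrink to a single loop are
// dropped, since collapsing needs at least two adjacent loops.
std::vector<Group> subtractUncollapsible(const std::vector<Group> &groups,
                                         const std::set<long> &touched) {
  std::vector<Group> result;
  for (const Group &group : groups) {
    Group current;
    for (long loop : group) {
      if (touched.count(loop)) {
        // `loop` is used by the uncollapsible neighbor: split here.
        if (current.size() > 1)
          result.push_back(current);
        current.clear();
      } else {
        current.push_back(loop);
      }
    }
    if (current.size() > 1)
      result.push_back(current);
  }
  return result;
}

int main() {
  // The producer would like to collapse loops [[0, 1, 2, 3]], but the value
  // it shares with an uncollapsible consumer only involves loops 2 and 3
  // (say, the other loops do not appear in that result). Only the untouched
  // dimensions may still be collapsed.
  std::vector<Group> candidate = {{0, 1, 2, 3}};
  std::set<long> touchedByConsumer = {2, 3};
  for (const Group &group : subtractUncollapsible(candidate, touchedByConsumer)) {
    std::printf("{");
    for (long loop : group)
      std::printf(" %ld", loop);
    std::printf(" }\n");
  }
  // Prints "{ 0 1 }": loops 0 and 1 may still be collapsed together.
}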
@@ -791,35 +810,18 @@ updateConsumersFromProducers(ArrayRef<Operation *> slice,
         continue;
       }
 
-      // Track the dimensions that are not collapsable by this current op.
-      // Initialize this with all loops in thel producer. Note: the dims are
-      // relative to the consumers iteration space, not the producers. This
-      // cannot be done via union of producer and consumer collapsable loops
-      // because the consumer may have loops that the producer does not.
-      CollapseInfo::CollapsableLoopsSet producerUncollapsable;
-      for (auto expr :
-           consumerOp.getMatchingIndexingMap(operand).getResults()) {
-        producerUncollapsable.insert(cast<AffineDimExpr>(expr).getPosition());
-      }
-
-      FailureOr<AffineMap> mapping =
-          getProducerLoopToConsumerLoopsMap(*operand);
-
-      // If there is no mapping or we can't find the op, the tensor is
-      // not collapsable. So, all dimensions of the producer are uncollapsable.
-      if (!opMap.contains(producerOp) || failed(mapping)) {
-        didChange |=
-            consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
+      // If we can't find the op, the tensor is not collapsable. So, consider
+      // all the dimensions of the producer to be uncollapsable.
+      if (!opMap.contains(producerOp)) {
+        didChange |= consumerInfo.updateFromUncollapsible(consumerOp, operand);
         continue;
       }
 
       const CollapseInfo &producerInfo = opMap.at(producerOp);
-      CollapseInfo::CollapsableLoopsSet producerCollapsable =
-          producerInfo.getTransformedCollapsableLoops(mapping.value());
-      producerUncollapsable.set_subtract(producerCollapsable);
-
+      FailureOr<AffineMap> consumerToProducerMap =
+          getProducerLoopToConsumerLoopsMap(*operand);
       didChange |=
-          consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
+          consumerInfo.updateFromOther(consumerToProducerMap, producerInfo);
     }
   }
   return didChange;
@@ -837,21 +839,31 @@ updateProducersFromConsumers(ArrayRef<Operation *> slice,
   // Iterate over `slice` in reverse so that we visit each `op` 's consumer
   // before visiting `op`.
   for (auto op : llvm::reverse(slice)) {
-    auto consumerOp = cast<DestinationStyleOpInterface>(op);
-    const CollapseInfo &consumerInfo = opMap.at(consumerOp);
+    auto producerOp = cast<LinalgFusionOpInterface>(op);
+    CollapseInfo &producerInfo = opMap.find(producerOp)->second;
 
-    for (auto *operand : consumerOp.getDpsInputOperands()) {
-      auto definingOp = operand->get().getDefiningOp();
-      if (!definingOp || !opMap.contains(definingOp)) {
+    for (auto &operand : producerOp->getUses()) {
+      auto *consumerOp = operand.getOwner();
+      if (consumerOp->hasTrait<OpTrait::IsTerminator>()) {
+        continue;
+      }
+
+      // If we can't find the op, the tensor is not collapsable. So, consider
+      // all the dimensions of the consumer to be uncollapsable.
+      if (!opMap.contains(consumerOp)) {
+        didChange |= producerInfo.updateFromUncollapsible(producerOp, &operand);
        continue;
      }
 
       // Get a mapping from the consumer's iteration space to the producer's.
-      CollapseInfo &producerInfo = opMap.find(definingOp)->second;
+      const CollapseInfo &consumerInfo = opMap.at(consumerOp);
 
       // Only loops collapsable in both the consumer and producer may be
       // collapsed.
-      didChange |= producerInfo.updateFromConsumer(operand, consumerInfo);
+      FailureOr<AffineMap> consumerToProducerMap =
+          getConsumerLoopToProducerLoopsMap(operand);
+      didChange |=
+          producerInfo.updateFromOther(consumerToProducerMap, consumerInfo);
     }
   }
   return didChange;
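A small aside on the traversal order used above: the comment notes that the slice is walked in reverse so each op's consumer is visited before the op itself, which lets the final consumer's constraints reach the head of a producer chain in a single sweep. A toy illustration in plain C++; the ops here all share one loop numbering, so a plain set intersection stands in for the mapped `updateFromOther` of the real pass.

#include <cstdio>
#include <set>
#include <vector>

// Each op in a producer->consumer chain has a set of loops it is willing to
// collapse. Intersecting each op with its already-updated consumer, walking
// the chain in reverse, reaches the fixed point for the chain in one sweep.
int main() {
  std::vector<std::set<long>> collapsable = {
      {0, 1, 2, 3}, // op0 (first producer)
      {0, 1, 2},    // op1
      {0, 1},       // op2 (final consumer)
  };

  // Reverse walk: visit op1 against op2, then op0 against the updated op1.
  for (int i = static_cast<int>(collapsable.size()) - 2; i >= 0; --i) {
    std::set<long> intersection;
    for (long loop : collapsable[i])
      if (collapsable[i + 1].count(loop))
        intersection.insert(loop);
    collapsable[i] = intersection;
  }

  // op0 already reflects op2's constraint: {0, 1}. A forward walk would have
  // updated op0 against the not-yet-updated op1 and needed a second sweep.
  for (const auto &loops : collapsable) {
    std::printf("{");
    for (long loop : loops)
      std::printf(" %ld", loop);
    std::printf(" }\n");
  }
}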

compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir

Lines changed: 98 additions & 1 deletion
@@ -16,7 +16,7 @@ util.func public @do_not_collapse_cst_in_place(%arg0: tensor<1x1x2304xf32>) {
   util.return
 }
 // CHECK-LABEL: util.func public @do_not_collapse_cst_in_place
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]]]
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
 // CHECK-DAG: %[[CST:.+]] = arith.constant
 // CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]]
 // CHECK-DAG: %[[COLLAPSED_CST:.+]] = tensor.collapse_shape %[[CST]]
@@ -656,3 +656,100 @@ util.func public @collapse(%10: tensor<2x32x32x1280xi8>, %11 : tensor<10240x1280
 // CHECK: %[[GEN1:.*]] = linalg.generic
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]
 // CHECK: flow.return %[[GEN1]] : tensor<2048x10240xf16>
+
+// -----
+
+util.func public @update_from_producer(%arg0: tensor<2x1x256x16x16xi8>, %arg1: tensor<2x1x256xf32>) -> tensor<1x256x16x16xi8> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = flow.dispatch.region -> (tensor<1x256x16x16xi8>) {
+    %1 = tensor.empty() : tensor<1x256x16x16xi8>
+    %2 = tensor.empty() : tensor<1x256x16x16xf32>
+    %3 = tensor.empty() : tensor<2x1x256x16x16xf32>
+    %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x1x256x16x16xi8>) outs(%3 : tensor<2x1x256x16x16xf32>) {
+    ^bb0(%in: i8, %out: f32):
+      %8 = arith.extsi %in : i8 to i32
+      %9 = arith.sitofp %8 : i32 to f32
+      linalg.yield %9 : f32
+    } -> tensor<2x1x256x16x16xf32>
+    %5 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1x256x16x16xf32>) -> tensor<1x256x16x16xf32>
+    %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%4, %arg1 : tensor<2x1x256x16x16xf32>, tensor<2x1x256xf32>) outs(%5 : tensor<1x256x16x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %8 = arith.mulf %in, %in_0 : f32
+      %9 = arith.addf %8, %out : f32
+      linalg.yield %9 : f32
+    } -> tensor<1x256x16x16xf32>
+    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<1x256x16x16xf32>) outs(%1 : tensor<1x256x16x16xi8>) {
+    ^bb0(%in: f32, %out: i8):
+      %8 = arith.fptosi %in : f32 to i8
+      linalg.yield %8 : i8
+    } -> tensor<1x256x16x16xi8>
+    flow.return %7 : tensor<1x256x16x16xi8>
+  }
+  util.return %0 : tensor<1x256x16x16xi8>
+}
+
+// CHECK-LABEL: util.func public @update_from_producer
+// CHECK: %[[GEN0:.*]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK: %[[GEN1:.*]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[GEN0]]
+// CHECK: %[[GEN2:.*]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[GEN1]]
+// CHECK: flow.return %[[GEN2]] : tensor<256x256xi8>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+util.func public @uncollapsable_consumer(%arg0: tensor<1x1x2304xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : tensor<1x1x2304xf32>
+  %0 = tensor.empty() : tensor<1x1x2304xf32>
+  %1 = flow.dispatch.region -> (tensor<1x1x2304xf32>) {
+    %2 = tensor.empty() : tensor<1x1x2304xf32>
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst : tensor<1x1x2304xf32>, tensor<1x1x2304xf32>) outs(%2 : tensor<1x1x2304xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      linalg.yield %4 : f32
+    } -> tensor<1x1x2304xf32>
+    %10 = util.optimization_barrier %3 : tensor<1x1x2304xf32>
+    flow.return %3 : tensor<1x1x2304xf32>
+  }
+  util.return
+}
+// CHECK-LABEL: util.func public @uncollapsable_consumer
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant
+// CHECK: %{{.+}} = flow.dispatch.region
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[CST]]
+// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
+// CHECK: flow.return %[[RES]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+util.func public @uncollapsable_consumer_partial(%arg0: tensor<10x20x30x2304xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : tensor<10x20x30x2304xf32>
+  %0 = tensor.empty() : tensor<30x2304xf32>
+  %1 = flow.dispatch.region -> (tensor<30x2304xf32>) {
+    %2 = tensor.empty() : tensor<30x2304xf32>
+    %3 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %cst : tensor<10x20x30x2304xf32>, tensor<10x20x30x2304xf32>) outs(%2 : tensor<30x2304xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      linalg.yield %4 : f32
+    } -> tensor<30x2304xf32>
+    %10 = util.optimization_barrier %3 : tensor<30x2304xf32>
+    flow.return %3 : tensor<30x2304xf32>
+  }
+  util.return
+}
+// CHECK-LABEL: util.func public @uncollapsable_consumer_partial
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant
+// CHECK: %{{.+}} = flow.dispatch.region
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
+// CHECK: flow.return %[[RES]]
