Skip to content

Commit f5dc573

Browse files
authored
[DispatchCreation] CollapseDimensions patch (#18424)
Fixes the case where parallel and reduction iterators (which are collapsable) are adjacent. They cannot be collapsed into each other in the producer because parallel and reduction dimensions are kept separate.
1 parent a9c7ec1 commit f5dc573

File tree

2 files changed

+172
-46
lines changed

2 files changed

+172
-46
lines changed

compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp

Lines changed: 131 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#include "iree/compiler/DispatchCreation/Passes.h"
1111
#include "llvm/ADT/STLExtras.h"
1212
#include "llvm/ADT/SetVector.h"
13+
#include "llvm/ADT/SmallVectorExtras.h"
1314
#include "llvm/Support/Debug.h"
15+
#include "llvm/Support/LogicalResult.h"
1416
#include "llvm/Support/raw_ostream.h"
1517
#include "mlir/Analysis/SliceAnalysis.h"
1618
#include "mlir/Dialect/Affine/Utils.h"
@@ -280,10 +282,15 @@ class CollapseInfo {
280282
// Debug print the current operation & reassociation indices
281283
void dump() const;
282284

283-
// Update `collapsableLoops` by taking the set intersection with
284-
// `otherCollapsable` and update the reassociation indicies accordingly.
285+
// Update CollapseInfo to ensure that all dimensions collapsable in `this` are
286+
// also collapsable in `consumerInfo`. This means:
287+
// 1. Any dimension not collapsable in `consumerInfo` should not be
288+
// collapsable in `this`
289+
// 2. For any pair of dimensions in `this`, if they are collapsable in
290+
// `consumerInfo`, they must be collapsable into the same dimension in
291+
// `consumerInfo` to be collapsable into the same dimension in `this`.
285292
// Returns true if the collapsable loops or the reassociation were modified.
286-
bool updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable);
293+
bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo);
287294

288295
// Update `collapsableLoops` by subtracting `uncollapsable` and update the
289296
// reassociation indices accordingly.
@@ -293,13 +300,18 @@ class CollapseInfo {
293300
// Get `collapsableLoops` after applying the transformation provided by `map`.
294301
// Note: doesn't modify `collapsableLoops`, the tranformation is applied to a
295302
// copy.
296-
FailureOr<CollapsableLoopsSet>
297-
getTransformedCollapsableLoops(AffineMap map) const;
303+
CollapsableLoopsSet getTransformedCollapsableLoops(AffineMap map) const;
298304

299-
// Clear internal data
300-
void clear() {
305+
// Get `reassociation` after applying the transformation provided by `map`.
306+
SmallVector<ReassociationIndices>
307+
getTransformedReassociation(AffineMap map) const;
308+
309+
// Clear internal data and returns if anything changed.
310+
bool clear() {
311+
bool isNotEmpty = reassociation.empty() || collapsableLoops.empty();
301312
reassociation.clear();
302313
collapsableLoops.clear();
314+
return isNotEmpty;
303315
}
304316

305317
const CollapsableLoopsSet &getCollapsibleLoops() const {
@@ -386,12 +398,8 @@ void CollapseInfo::updateReassociation() {
386398
// map = affine_map<(d0, d1, d2) -> (d1, d2, d5)>
387399
//
388400
// Therefore, the collapsable loops with respect to the consumer is {1, 2, 5}.
389-
FailureOr<CollapseInfo::CollapsableLoopsSet>
401+
CollapseInfo::CollapsableLoopsSet
390402
CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
391-
if (!map) {
392-
return failure();
393-
}
394-
395403
CollapsableLoopsSet transformedLoops;
396404
for (auto index : collapsableLoops) {
397405
assert(index < map.getNumResults() && "index has no valid mapping");
@@ -405,19 +413,114 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
405413
return transformedLoops;
406414
}
407415

408-
// Update `collapsableLoops` by taking the set intersection with
409-
// `otherCollapsable` and update the reassociation indicies accordingly.
410-
bool CollapseInfo::updateCollapseViaIntersect(
411-
const CollapsableLoopsSet &otherCollapsable) {
412-
CollapsableLoopsSet toRemove;
413-
for (auto elem : collapsableLoops) {
414-
if (!otherCollapsable.contains(elem)) {
415-
toRemove.insert(elem);
416+
SmallVector<ReassociationIndices>
417+
CollapseInfo::getTransformedReassociation(AffineMap map) const {
418+
SmallVector<ReassociationIndices> transformedReassociation(
419+
reassociation.size());
420+
for (const auto &[i, indicies] : llvm::enumerate(reassociation)) {
421+
for (auto elem : indicies) {
422+
auto dimExpr = dyn_cast<AffineDimExpr>(map.getResult(elem));
423+
if (!dimExpr) {
424+
break;
425+
}
426+
transformedReassociation[i].push_back(dimExpr.getPosition());
416427
}
417428
}
418-
collapsableLoops.set_subtract(toRemove);
419-
updateReassociation();
420-
return toRemove.size();
429+
return transformedReassociation;
430+
}
431+
432+
bool CollapseInfo::updateFromConsumer(OpOperand *operand,
433+
const CollapseInfo &consumerInfo) {
434+
FailureOr<AffineMap> consumerToProducerMap =
435+
getConsumerLoopToProducerLoopsMap(*operand);
436+
if (failed(consumerToProducerMap)) {
437+
return this->clear();
438+
}
439+
440+
CollapsableLoopsSet consumerCollapsable =
441+
consumerInfo.getTransformedCollapsableLoops(
442+
consumerToProducerMap.value());
443+
444+
SmallVector<ReassociationIndices> consumerReassoc =
445+
consumerInfo.getTransformedReassociation(consumerToProducerMap.value());
446+
447+
// Get a map from original index to the index it gets collapsed into
448+
llvm::DenseMap<long, long> consumerCollapseMap;
449+
for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) {
450+
for (const auto elem : indicies) {
451+
consumerCollapseMap[elem] = idx;
452+
}
453+
}
454+
455+
// Remove all collapsable loops in `producer` that are not collapsable in
456+
// `consumer` (set intersect)
457+
bool didChange = collapsableLoops.remove_if(
458+
[&](long elem) -> bool { return !consumerCollapsable.contains(elem); });
459+
460+
// Now update the reassociation indices given the updated `collapsableLoops`
461+
// and `consumerCollapseMap`.
462+
// The idea is to reconstruct the reassociation indices, and at each index:
463+
// (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add
464+
// `index` to either.
465+
//
466+
// (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps
467+
// `index` to a different collapsed loop than the other indices, split
468+
// `indicies` and insert `index` into the new one.
469+
//
470+
// For example:
471+
// producer reassociation = [[0, 1, 2, 3]]
472+
// consumer reassociation = [[0, 1], [2, 3]]
473+
// then, producer reassociation gets updated to [[0, 1], [2, 3]] because
474+
// [0, 1] and [2, 3] get collapsed into different loops in the consumer
475+
//
476+
// (3) Otherwise, keep the index
477+
constexpr long kUninitialized = -1;
478+
SmallVector<ReassociationIndices> newReassociation;
479+
for (ReassociationIndicesRef indicies : reassociation) {
480+
// Track the loop index that `indicies` get collapsed into.
481+
long collapseIntoIdx = kUninitialized;
482+
483+
// Holds dimensions that should be collapsed together
484+
ReassociationIndices newIndicies;
485+
for (int64_t index : indicies) {
486+
if (!collapsableLoops.contains(index)) {
487+
// (1) Because `index` isn't collapsable, the indices in `newIndicies`
488+
// are no longer adjacent to the upcoming indices. If there is >1 index
489+
// to collapse, add it to the new reassociation. Otherwise, discard it
490+
// because there is no dimension to collapse with.
491+
didChange = true;
492+
if (newIndicies.size() > 1) {
493+
newReassociation.push_back(std::move(newIndicies));
494+
}
495+
newIndicies.clear();
496+
collapseIntoIdx = kUninitialized;
497+
} else if (collapseIntoIdx == kUninitialized) {
498+
// (2) First occurrence of collapsable loop, set collapseIntoIdx.
499+
collapseIntoIdx = consumerCollapseMap.at(index);
500+
newIndicies.push_back(index);
501+
} else if (consumerCollapseMap.at(index) != collapseIntoIdx) {
502+
// (3) `index` is collapsable but not collapsable into the other loops.
503+
// So, split them and look for other loops to collapse `index` into.
504+
didChange = true;
505+
if (newIndicies.size() > 1) {
506+
newReassociation.push_back(std::move(newIndicies));
507+
}
508+
newIndicies.clear();
509+
collapseIntoIdx = consumerCollapseMap[index];
510+
newIndicies.push_back(index);
511+
} else {
512+
// (4) `index` is collapsable and can be collapsed into
513+
// `collapseIntoIdx`.
514+
newIndicies.push_back(index);
515+
}
516+
}
517+
518+
if (newIndicies.size() > 1) {
519+
newReassociation.push_back(newIndicies);
520+
}
521+
}
522+
reassociation = std::move(newReassociation);
523+
return didChange;
421524
}
422525

423526
// Update `collapsableLoops` by subtracting `uncollapsable` and update the
@@ -679,12 +782,10 @@ static bool updateConsumersFromProducers(
679782
continue;
680783
}
681784

682-
CollapseInfo &producerInfo = opMap.find(producerOp)->second;
683-
FailureOr<CollapseInfo::CollapsableLoopsSet> producerCollapsable =
785+
const CollapseInfo &producerInfo = opMap.at(producerOp);
786+
CollapseInfo::CollapsableLoopsSet producerCollapsable =
684787
producerInfo.getTransformedCollapsableLoops(mapping.value());
685-
if (!failed(producerCollapsable)) {
686-
producerUncollapsable.set_subtract(producerCollapsable.value());
687-
}
788+
producerUncollapsable.set_subtract(producerCollapsable);
688789

689790
didChange |=
690791
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
@@ -707,7 +808,7 @@ static bool updateProducersFromConsumers(
707808
for (auto op : llvm::reverse(slice)) {
708809
auto genericConsumer = cast<linalg::GenericOp>(op);
709810
assert(opMap.contains(genericConsumer));
710-
const CollapseInfo &consumerInfo = opMap.find(genericConsumer)->second;
811+
const CollapseInfo &consumerInfo = opMap.at(genericConsumer);
711812

712813
for (auto operand : genericConsumer.getDpsInputOperands()) {
713814
auto definingOp = operand->get().getDefiningOp();
@@ -721,26 +822,10 @@ static bool updateProducersFromConsumers(
721822

722823
// Get a mapping from the consumer's iteration space to the producer's.
723824
CollapseInfo &producerInfo = opMap.find(genericProducer)->second;
724-
FailureOr<AffineMap> consumerToProducerMap =
725-
getConsumerLoopToProducerLoopsMap(*operand);
726-
if (failed(consumerToProducerMap)) {
727-
didChange |= !producerInfo.getCollapsibleLoops().empty();
728-
producerInfo.clear();
729-
continue;
730-
}
731825

732-
// Use the map to get the consumer's collapsable loops in terms of the
733-
// producer.
734-
auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops(
735-
consumerToProducerMap.value());
736-
if (failed(consumerCollapsable)) {
737-
producerInfo.clear();
738-
continue;
739-
}
740826
// Only loops collapsable in both the consumer and producer may be
741827
// collapsed.
742-
didChange |=
743-
producerInfo.updateCollapseViaIntersect(consumerCollapsable.value());
828+
didChange |= producerInfo.updateFromConsumer(operand, consumerInfo);
744829
}
745830
}
746831
return didChange;

compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,3 +518,44 @@ util.func public @propagate_uncollapsable(%arg0: tensor<2x320x128x128xf32>) -> t
518518
// CHECK-SAME: ins(%[[VAL2]], %[[VAL1]] : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>)
519519
// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>)
520520
// CHECK: flow.return %[[VAL3]]
521+
522+
// -----
523+
524+
util.func public @dequant_contraction(%arg0: tensor<2x32xf32>, %arg1: tensor<2x32x10x16384xf16>) -> tensor<2x32xf32> {
525+
%0 = flow.dispatch.region -> (tensor<2x32xf32>) {
526+
%1 = tensor.empty() : tensor<2x32xf32>
527+
%cst = arith.constant 0.000000e+00 : f32
528+
%2 = tensor.empty() : tensor<2x32x10x16384xf32>
529+
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x32x10x16384xf16>) outs(%2 : tensor<2x32x10x16384xf32>) {
530+
^bb0(%in: f16, %out: f32):
531+
%6 = arith.extf %in : f16 to f32
532+
linalg.yield %6 : f32
533+
} -> tensor<2x32x10x16384xf32>
534+
%4 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
535+
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %arg0 : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) outs(%4 : tensor<2x32xf32>) {
536+
^bb0(%in: f32, %in_0: f32, %out: f32):
537+
%6 = arith.subf %in, %in_0 : f32
538+
%7 = arith.mulf %6, %6 : f32
539+
%8 = arith.addf %7, %out : f32
540+
linalg.yield %8 : f32
541+
} -> tensor<2x32xf32>
542+
flow.return %5 : tensor<2x32xf32>
543+
}
544+
util.return %0 : tensor<2x32xf32>
545+
}
546+
547+
// CHECK-LABEL: util.func public @dequant_contraction
548+
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x32xf32>
549+
// CHECK-SAME: %[[ARG1:.+]]: tensor<2x32x10x16384xf16>
550+
// CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]]
551+
// CHECK-DAG: %[[COLLAPSED_ARG1:.+]] = tensor.collapse_shape %[[ARG1]]
552+
// CHECK: flow.dispatch.region
553+
// CHECK: %[[VAL0:.*]] = linalg.generic
554+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
555+
// CHECK-SAME: ins(%[[COLLAPSED_ARG1]] : tensor<64x163840xf16>)
556+
// CHECK-SAME: outs(%{{.*}} : tensor<64x163840xf32>)
557+
// CHECK: %[[VAL1:.*]] = linalg.generic
558+
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
559+
// CHECK-SAME: ins(%[[VAL0]], %[[COLLAPSED_ARG0]] : tensor<64x163840xf32>, tensor<64xf32>)
560+
// CHECK-SAME: outs(%{{.*}} : tensor<64xf32>)
561+
// CHECK: flow.return %[[VAL1]]

0 commit comments

Comments
 (0)