Skip to content

Commit f1e9219

Browse files
dan-garveyIanWood1
andauthored
[DispatchCreation] Fix trailing unit dims case for collapse of expand folding (iree-org#21677)
Previous logic only prevented collapsing all unit dims in a reassociation element if the first element of the reassociation represented a unit dim in the input. i.e. it worked for cases like 1x1x44x5 -> 1x44x5 but failed for 5x44x1x1 -> 5x44x1 --------- Signed-off-by: dan <[email protected]> Signed-off-by: Ian Wood <[email protected]> Co-authored-by: Ian Wood <[email protected]>
1 parent e6fb1e1 commit f1e9219

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,10 @@ struct DropUnitDimsFromCollapseOfExpand
8787
continue;
8888
}
8989

90-
// If we are collapsing multiple unit dims together, at least 1 must be
91-
// kept (prefer the first).
92-
if (outShape[outDim] == 1 && innerIdx != 0) {
90+
// If outShape[outDim] == 1, we must preserve 1 unit dim,
91+
// so we drop the first. If the first is the only unit dim,
92+
// we can't drop it anyway.
93+
if (outShape[outDim] == 1 && innerIdx == 0) {
9394
continue;
9495
}
9596
toDrop.insert(inDim);
@@ -99,8 +100,13 @@ struct DropUnitDimsFromCollapseOfExpand
99100
// Remove dimensions from `toDrop` that weren't introduced by the
100101
// `expandOp` op.
101102
const auto expandReassoc = expandOp.getReassociationIndices();
102-
for (const auto &[inDim, indices] : llvm::enumerate(expandReassoc)) {
103-
if (indices.size() == 1) {
103+
for (const auto &indices : expandReassoc) {
104+
// If all of indices are in `toDrop`, we must preserve at least one
105+
// to avoid an empty reassociation map during expansion.
106+
// This can happen when outShape does not have a unit dimension
107+
// corresponding to the unit dimensions being dropped here.
108+
if (llvm::all_of(indices,
109+
[&](int64_t idx) { return toDrop.contains(idx); })) {
104110
toDrop.erase(indices[0]);
105111
}
106112
}

compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,35 @@ util.func @collapse_of_expand_to_scalar(%arg0: tensor<1x1xf16>, %arg1: index, %a
312312
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
313313
// CHECK-SAME: tensor<1x1xf16> into tensor<f16>
314314
// CHECK: util.return %[[COLLAPSED]] : tensor<f16>
315+
316+
// -----
317+
318+
util.func @collapse_of_expand_trailing_unit_dims(%arg0: tensor<23040x1xbf16>) -> tensor<4x5760xbf16> {
319+
%expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [4, 5760, 1, 1] : tensor<23040x1xbf16> into tensor<4x5760x1x1xbf16>
320+
%collapsed = tensor.collapse_shape %expanded [[0], [1, 2, 3]] : tensor<4x5760x1x1xbf16> into tensor<4x5760xbf16>
321+
util.return %collapsed : tensor<4x5760xbf16>
322+
}
323+
// CHECK-LABEL: util.func public @collapse_of_expand_trailing_unit_dims
324+
// CHECK-SAME: %[[ARG0:.+]]: tensor<23040x1xbf16>
325+
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
326+
// CHECK-SAME: tensor<23040x1xbf16> into tensor<4x5760x1xbf16>
327+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
328+
// CHECK-SAME: tensor<4x5760x1xbf16> into tensor<4x5760xbf16>
329+
// CHECK: util.return %[[COLLAPSE]] : tensor<4x5760xbf16>
330+
331+
// -----
332+
333+
// This test considers the case where we have multiple trailing unit dims but must preserve one for the output,
334+
// as well as an isolated unit dim that must be preserved for the collapse's reassociation dims.
335+
util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040xbf16>) -> tensor<4x5760x1xbf16> {
336+
%expanded = tensor.expand_shape %arg0 [[0], [1, 2, 3, 4, 5]] output_shape [1, 4, 5760, 1, 1, 1] : tensor<1x23040xbf16> into tensor<1x4x5760x1x1x1xbf16>
337+
%collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4, 5]] : tensor<1x4x5760x1x1x1xbf16> into tensor<4x5760x1xbf16>
338+
util.return %collapsed : tensor<4x5760x1xbf16>
339+
}
340+
// CHECK-LABEL: util.func public @collapse_of_expand_preserved_trailing_unit_dims
341+
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x23040xbf16>
342+
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
343+
// CHECK-SAME: tensor<1x23040xbf16> into tensor<1x4x5760x1xbf16>
344+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
345+
// CHECK-SAME: tensor<1x4x5760x1xbf16> into tensor<4x5760x1xbf16>
346+
// CHECK: util.return %[[COLLAPSE]] : tensor<4x5760x1xbf16>

0 commit comments

Comments
 (0)