
Commit dca3747

[Dispatch Creation] Drop unit dims from tensor.extract ops (#22503)
Adds a pattern for `tensor.extract` that folds away unit dimensions. Without this pattern, reshapes are left in the program and may re-introduce the unit dims when propagated. The added xfail is already tracked by issue #20011.

Signed-off-by: Ian Wood <[email protected]>
1 parent 7667525 commit dca3747
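
In short: when a `tensor.extract` reads from a tensor with unit dimensions, the pattern collapses those dimensions away and drops the corresponding indices (which must be 0 to be in bounds). A minimal before/after sketch of the rewrite, using a hypothetical function and shape chosen for illustration:

util.func @example(%src: tensor<1x4x1x8xf32>, %i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
  // Before the pattern runs: the extract indexes all four dims.
  %v = tensor.extract %src[%i0, %i1, %i2, %i3] : tensor<1x4x1x8xf32>
  util.return %v : f32
}

// After: each unit dim is grouped with the following non-unit dim and
// collapsed away; the indices into the unit dims are dropped.
util.func @example(%src: tensor<1x4x1x8xf32>, %i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]]
      : tensor<1x4x1x8xf32> into tensor<4x8xf32>
  %v = tensor.extract %collapsed[%i1, %i3] : tensor<4x8xf32>
  util.return %v : f32
}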

4 files changed: +171 −6 lines changed


compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp

Lines changed: 55 additions & 6 deletions
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -204,11 +205,58 @@ struct DropUnitDimsFromCollapseOfExpand
   }
 };

-} // namespace
+// Fold unit dims from `tensor.extract` ops.
+struct FoldUnitDimsFromExtractOp : OpRewritePattern<tensor::ExtractOp> {
+  using Base::Base;
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType srcType = extractOp.getTensor().getType();
+    if (srcType.getShape().empty() ||
+        llvm::none_of(srcType.getShape(),
+                      [](int64_t size) { return size == 1; })) {
+      return failure();
+    }
+    SmallVector<Value> oldIndices = extractOp.getIndices();
+
+    SmallVector<int64_t> newShape;
+    SmallVector<Value> newIndices;
+    SmallVector<ReassociationIndices> reassoc;
+    ReassociationIndices currReassoc;
+
+    // Build reassociation groups where each non-unit dimension forms one
+    // output dimension, and unit dimensions are grouped with adjacent
+    // non-unit dims.
+    for (auto [idx, size] : llvm::enumerate(srcType.getShape())) {
+      currReassoc.push_back(idx);
+
+      if (size != 1) {
+        // Non-unit dimension: this forms one output dimension.
+        // Finish the current group and start a new one.
+        reassoc.push_back(std::move(currReassoc));
+        currReassoc.clear();
+        newShape.push_back(size);
+        newIndices.push_back(oldIndices[idx]);
+      }
+    }

-//===----------------------------------------------------------------------===//
-// Pass helpers
-//===----------------------------------------------------------------------===//
+    // If we have trailing unit dims, merge them with the last group.
+    if (!currReassoc.empty() && !reassoc.empty()) {
+      reassoc.back().append(currReassoc.begin(), currReassoc.end());
+    }
+
+    rewriter.setInsertionPointAfterValue(extractOp.getTensor());
+    auto collapseOp = tensor::CollapseShapeOp::create(
+        rewriter, extractOp.getLoc(), extractOp.getTensor(), reassoc);
+
+    rewriter.setInsertionPointAfter(extractOp);
+    auto newExtract = tensor::ExtractOp::create(
+        rewriter, extractOp.getLoc(), extractOp.getResult().getType(),
+        collapseOp.getResult(), newIndices);
+    rewriter.replaceOp(extractOp, newExtract);
+    return success();
+  }
+};
+
+} // namespace

 static void
 populatefoldUnitDimsPatterns(RewritePatternSet &foldUnitDimsPatterns) {
@@ -230,8 +278,9 @@ populatefoldUnitDimsPatterns(RewritePatternSet &foldUnitDimsPatterns) {
   IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
                                                       options);
   linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
-  foldUnitDimsPatterns.insert<DropUnitDimsFromCollapseOfExpand>(
-      foldUnitDimsPatterns.getContext());
+  foldUnitDimsPatterns
+      .insert<DropUnitDimsFromCollapseOfExpand, FoldUnitDimsFromExtractOp>(
+          foldUnitDimsPatterns.getContext());
 }

 static LogicalResult
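
To make the grouping concrete: for a source of type tensor<1x4x1x8x1xf32> (the shape used by the @fold_unit_dims_from_extract_multiple test below), the loop attaches each leading unit dim to the next non-unit dim, and the post-loop fixup appends the trailing unit dim to the last group, so the reassociation is [[0, 1], [2, 3, 4]]. A sketch of the resulting IR, with illustrative value names:

// Unit dims 0 and 2 are grouped with the non-unit dims 1 and 3 that follow
// them; the trailing unit dim 4 is merged into the last group.
%collapsed = tensor.collapse_shape %src [[0, 1], [2, 3, 4]]
    : tensor<1x4x1x8x1xf32> into tensor<4x8xf32>
%v = tensor.extract %collapsed[%idx1, %idx3] : tensor<4x8xf32>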

compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir

Lines changed: 114 additions & 0 deletions
@@ -344,3 +344,117 @@ util.func @collapse_of_expand_preserved_trailing_unit_dims(%arg0: tensor<1x23040
 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
 // CHECK-SAME: tensor<1x4x5760x1xbf16> into tensor<4x5760x1xbf16>
 // CHECK: util.return %[[COLLAPSE]] : tensor<4x5760x1xbf16>
+
+// -----
+
+util.func @fold_unit_dims_from_extract_leading(%arg0: tensor<1x4x8xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x4x8xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_leading
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x4x8xf32>
+// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK-SAME: tensor<1x4x8xf32> into tensor<4x8xf32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX1]], %[[IDX2]]]
+// CHECK: util.return %[[EXTRACT]] : f32
+
+// -----
+
+util.func @fold_unit_dims_from_extract_trailing(%arg0: tensor<4x8x1xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<4x8x1xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_trailing
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x8x1xf32>
+// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]{{\]}}
+// CHECK-SAME: tensor<4x8x1xf32> into tensor<4x8xf32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX0]], %[[IDX1]]]
+// CHECK: util.return %[[EXTRACT]] : f32
+
+// -----
+
+util.func @fold_unit_dims_from_extract_middle(%arg0: tensor<4x1x8xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<4x1x8xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_middle
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1x8xf32>
+// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]{{\]}}
+// CHECK-SAME: tensor<4x1x8xf32> into tensor<4x8xf32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX0]], %[[IDX2]]]
+// CHECK: util.return %[[EXTRACT]] : f32
+
+// -----
+
+util.func @fold_unit_dims_from_extract_multiple(%arg0: tensor<1x4x1x8x1xf32>, %idx0: index, %idx1: index, %idx2: index, %idx3: index, %idx4: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2, %idx3, %idx4] : tensor<1x4x1x8x1xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_multiple
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x4x1x8x1xf32>
+// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX4:[a-zA-Z0-9]+]]: index
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3, 4]{{\]}}
+// CHECK-SAME: tensor<1x4x1x8x1xf32> into tensor<4x8xf32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]]
+// CHECK: util.return %[[EXTRACT]] : f32
+
+// -----
+
+// Test folding consecutive unit dims from tensor.extract.
+util.func @fold_unit_dims_from_extract_consecutive(%arg0: tensor<1x1x1x8xf32>, %idx0: index, %idx1: index, %idx2: index, %idx3: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2, %idx3] : tensor<1x1x1x8xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_consecutive
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x1x8xf32>
+// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX3:[a-zA-Z0-9]+]]: index
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]{{\]}}
+// CHECK-SAME: tensor<1x1x1x8xf32> into tensor<8xf32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX3]]]
+// CHECK: util.return %[[EXTRACT]] : f32
+
+// -----
+
+// Test folding unit dims with dynamic dimensions.
+util.func @fold_unit_dims_from_extract_dynamic(%arg0: tensor<1x?x1xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x?x1xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_dynamic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x1xf32>
+// CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]: index
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}}
+// CHECK-SAME: tensor<1x?x1xf32> into tensor<?xf32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]][%[[IDX1]]]
+// CHECK: util.return %[[EXTRACT]] : f32
+
+// -----
+
+util.func @fold_unit_dims_from_extract_all_unit(%arg0: tensor<1x1x1xf32>, %idx0: index, %idx1: index, %idx2: index) -> f32 {
+  %extracted = tensor.extract %arg0[%idx0, %idx1, %idx2] : tensor<1x1x1xf32>
+  util.return %extracted : f32
+}
+// CHECK-LABEL: util.func public @fold_unit_dims_from_extract_all_unit
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] []
+// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[COLLAPSED]]
+// CHECK-SAME: tensor<f32>
+// CHECK: util.return %[[EXTRACT]] : f32

tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json

Lines changed: 1 addition & 0 deletions
@@ -346,6 +346,7 @@
 "onnx/node/generated/test_reduce_sum_empty_set_non_reduced_axis_zero",
 "onnx/node/generated/test_resize_downsample_scales_cubic_align_corners",
 "onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
+"onnx/node/generated/test_reversesequence_time",
 "onnx/node/generated/test_scan_sum",
 "onnx/node/generated/test_sce_mean_weight",
 "onnx/node/generated/test_sce_mean_weight_ii",

tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json

Lines changed: 1 addition & 0 deletions
@@ -352,6 +352,7 @@
 "onnx/node/generated/test_reduce_sum_empty_set_non_reduced_axis_zero",
 "onnx/node/generated/test_resize_downsample_scales_cubic_align_corners",
 "onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
+"onnx/node/generated/test_reversesequence_time",
 "onnx/node/generated/test_scan_sum",
 "onnx/node/generated/test_sce_mean_weight",
 "onnx/node/generated/test_sce_mean_weight_ii",
