Skip to content

Commit 4d0ea08

Browse files
authored
[Global Opt] Raise tensor.extract to input (iree-org#22434)
The PR that added `raiseTensorExtractToInput` (iree-org#14718) only used the `linalg.generic` as an intermediate step when trying to convert to a `tensor` view-like op. However, we should replace the generic op with the `tensor.extract`-less generic op when it can't be converted to a view-like op. For example, the linalg op in the test case `test_extract_to_transpose` cannot be converted into a view but should be converted to a simple transpose. Signed-off-by: Ian Wood <[email protected]>
1 parent 6c6d175 commit 4d0ea08

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,28 +1016,28 @@ struct RaiseSpecialOpsPass
10161016

10171017
// First walk the IR and try to raise any slice-like generics to tensor.
10181018
IRRewriter rewriter(context);
1019-
funcOp->walk([&](linalg::GenericOp op) {
1020-
linalg::GenericOp linalgOp = op;
1019+
SmallVector<linalg::GenericOp> genericOps;
1020+
funcOp->walk([&](linalg::GenericOp op) { genericOps.push_back(op); });
10211021

1022+
for (linalg::GenericOp linalgOp : genericOps) {
1023+
// Try raising tensor.extract to an input and create a linalg.generic.
10221024
OpBuilder::InsertionGuard guard(rewriter);
1023-
1024-
// Try raising to tensor.export and create an intermediate linalg.generic.
1025-
rewriter.setInsertionPoint(op);
1025+
rewriter.setInsertionPoint(linalgOp);
10261026
FailureOr<linalg::GenericOp> maybeNewOp =
10271027
raiseTensorExtractToInput(linalgOp, rewriter);
10281028
if (succeeded(maybeNewOp)) {
1029+
rewriter.replaceOp(linalgOp, *maybeNewOp);
10291030
linalgOp = *maybeNewOp;
10301031
}
10311032

10321033
// Try raising to a view-like operation. Replace if the op raising was
10331034
// successful.
1034-
rewriter.setInsertionPoint(op);
10351035
FailureOr<Operation *> maybeRaisedView =
10361036
tryRaiseToView(linalgOp, rewriter);
10371037
if (succeeded(maybeRaisedView)) {
1038-
rewriter.replaceOp(op, *maybeRaisedView);
1038+
rewriter.replaceOp(linalgOp, *maybeRaisedView);
10391039
}
1040-
});
1040+
}
10411041

10421042
// Next run a variety of raising patterns.
10431043
{

compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,26 @@ util.func public @test_slice_middle(%A : tensor<64x64x64xf32>, %B : tensor<64x64
293293

294294
// -----
295295

296+
#map = affine_map<(d0, d1) -> (d0, d1)>
297+
util.func public @test_extract_to_transpose(%A : tensor<64x64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
298+
%c0 = arith.constant 0 : index
299+
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
300+
^bb0(%out: f32):
301+
%i0 = linalg.index 0 : index
302+
%i1 = linalg.index 1 : index
303+
%extracted = tensor.extract %A[%i1, %i0] : tensor<64x64xf32>
304+
linalg.yield %extracted : f32
305+
} -> tensor<64x64xf32>
306+
util.return %0 : tensor<64x64xf32>
307+
}
308+
309+
// CHECK-LABEL: util.func public @test_extract_to_transpose
310+
// CHECK: %[[RESULT:.+]] = linalg.generic
311+
// CHECK-SAME: affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>
312+
// CHECK: util.return %[[RESULT]]
313+
314+
// -----
315+
296316
util.func public @test_trailing_elementwise(%arg0: tensor<180x320x1xf32>) -> tensor<320xf32> {
297317
%c0 = arith.constant 0 : index
298318
%c179 = arith.constant 179 : index

0 commit comments

Comments
 (0)