Commit 2b30f25

[GlobalOpt] Fix transpose propagation for index-semantic ops by interchanging indexing maps (#22248)
Index-semantic ops were previously treated as elementwise in the `SinkTransposeThroughUnaryElementwiseInput` and `BubbleTransposeThroughUnaryElementwiseDpsInit` patterns, which could not correctly update the indexing maps. After this change, ops with index semantics are no longer handled by these patterns; instead, they are processed by the `FuseTransposeWithProducerLinalgOp` pattern, which uses `linalg::interchangeGenericOp`. That function already handles index-semantic ops.

---------

Signed-off-by: Ziliang Zhang <[email protected]>
1 parent fcae3fc commit 2b30f25
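For context, the commit message points at `linalg::interchangeGenericOp` as the mechanism that makes `FuseTransposeWithProducerLinalgOp` safe for index-semantic ops: it permutes the whole iteration space (indexing maps, iterator types, and the body's `linalg.index` ops together) rather than swapping indexing maps alone. The sketch below is a minimal, hypothetical illustration of calling that utility from a rewrite; the helper name `interchangeProducerForTranspose` and its signature are assumptions, not the actual pattern code.

```cpp
// Minimal sketch, not the actual FuseTransposeWithProducerLinalgOp code:
// fold a consumer transpose's permutation into a producer linalg.generic by
// interchanging the producer's iteration space.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// `permutation` is assumed to already be expressed over the generic op's loop
// dimensions (deriving it from the linalg.transpose op is omitted here).
static FailureOr<linalg::GenericOp>
interchangeProducerForTranspose(RewriterBase &rewriter,
                                linalg::GenericOp producer,
                                ArrayRef<unsigned> permutation) {
  // linalg::interchangeGenericOp rewrites the indexing maps, iterator types,
  // and the linalg.index ops in the body consistently, which is why (per the
  // commit message) it handles ops with index semantics correctly.
  return linalg::interchangeGenericOp(rewriter, producer, permutation);
}
```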

2 files changed: +51 additions, -3 deletions

compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp

Lines changed: 11 additions & 3 deletions
@@ -761,6 +761,10 @@ class SinkTransposeThroughUnaryElementwiseInput
       return rewriter.notifyMatchFailure(genericOp, "non-elementwise generic");
     }
 
+    if (genericOp.hasIndexSemantics()) {
+      return rewriter.notifyMatchFailure(genericOp, "has index semantics");
+    }
+
     if (genericOp.getNumDpsInits() != 1) {
       return rewriter.notifyMatchFailure(genericOp,
                                          "unimplemented: multiple results");
@@ -865,6 +869,10 @@ class BubbleTransposeThroughUnaryElementwiseDpsInit
       return rewriter.notifyMatchFailure(transposeOp, "not elementwise");
     }
 
+    if (genericOp.hasIndexSemantics()) {
+      return rewriter.notifyMatchFailure(genericOp, "has index semantics");
+    }
+
     if (!genericOp->hasOneUse()) {
       return rewriter.notifyMatchFailure(transposeOp, "not single user");
     }
@@ -898,9 +906,9 @@ class BubbleTransposeThroughUnaryElementwiseDpsInit
     SmallVector<AffineMap> indexingMaps = getTransposedIndexingMaps(
         genericOp, inputOperand->getOperandNumber(), transposeMap);
 
-    // We do not need to update indexing maps because this is a unary
-    // elementwise op where the input and output maps are the same. Just
-    // replace the operands with transposed variants.
+    // We do not need to update indexing maps because this is an elementwise
+    // op where the input and output maps are the same.
+    // Just replace the operands with transposed variants.
     auto newGenericOp =
         mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
     newGenericOp.setIndexingMapsAttr(

compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir

Lines changed: 40 additions & 0 deletions
@@ -254,6 +254,46 @@ util.func public @do_not_propagate_to_matmul_in_dispatch(%lhs: tensor<16x16xf32>
 
 // -----
 
+util.func public @propagate_to_gather_like_ops(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<1xi16>) -> tensor<2x3x4x5xf32> {
+  %cst = arith.constant 0xFF800000 : f32
+  %empty_transposed = tensor.empty() : tensor<2x4x5x3xf32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4x5xf32>) outs(%empty_transposed : tensor<2x4x5x3xf32>) permutation = [0, 2, 3, 1]
+  %empty = tensor.empty() : tensor<2x4x5x3xf32>
+  %collapsed = tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+  %mask = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%transposed, %collapsed : tensor<2x4x5x3xf32>, tensor<i16>) outs(%empty : tensor<2x4x5x3xf32>) {
+  ^bb0(%in: f32, %in_0: i16, %out: f32):
+    %11 = linalg.index 3 : index
+    %12 = arith.index_cast %in_0 : i16 to index
+    %13 = arith.cmpi ult, %11, %12 : index
+    %14 = arith.select %13, %in, %cst : f32
+    linalg.yield %14 : f32
+  } -> tensor<2x4x5x3xf32>
+  %empty_transposed_0 = tensor.empty() : tensor<2x3x4x5xf32>
+  %transposed_0 = linalg.transpose ins(%mask : tensor<2x4x5x3xf32>) outs(%empty_transposed_0 : tensor<2x3x4x5xf32>) permutation = [0, 3, 1, 2]
+  util.return %transposed_0 : tensor<2x3x4x5xf32>
+}
+
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK-LABEL: util.func public @propagate_to_gather_like_ops(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x4x5xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1xi16>) -> tensor<2x3x4x5xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[VAL_1:.*]] = tensor.collapse_shape %[[ARG1]] [] : tensor<1xi16> into tensor<i16>
+// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<2x3x4x5xf32>
+// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[VAL_1]] : tensor<2x3x4x5xf32>, tensor<i16>) outs(%[[VAL_2]] : tensor<2x3x4x5xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: i16, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = linalg.index 1 : index
+// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_5]] : i16 to index
+// CHECK: %[[VAL_9:.*]] = arith.cmpi ult, %[[VAL_7]], %[[VAL_8]] : index
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_9]], %[[VAL_4]], %[[VAL_0]] : f32
+// CHECK: linalg.yield %[[VAL_10]] : f32
+// CHECK: } -> tensor<2x3x4x5xf32>
+// CHECK: util.return %[[VAL_3]] : tensor<2x3x4x5xf32>
+// CHECK: }
+
+// -----
+
 util.func public @propagate_to_bmm_transpose_batch(%transposed_lhs: tensor<16x2x16xf32>,
                                                    %rhs: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
   %empty = tensor.empty(): tensor<2x16x16xf32>
