Skip to content

Commit 3cd352b

Browse files
authored
TosaToLinalg: Prefer to emit identity maps (#386)
When deciding whether to emit a map like `#map = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>` or `#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` for and operand of a linalg.generic when lowering element wise TOSA ops, prefer the latter unless broadcasting of the operand is really needed. This helps later transformations which often require the affine map to be a projected permuatation, which only the latter is.
1 parent 08bb427 commit 3cd352b

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -925,8 +925,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
925925
auto shape = cast<ShapedType>(operand.getType()).getShape();
926926
SmallVector<AffineExpr> affineExprs;
927927
for (auto it : llvm::enumerate(shape)) {
928-
auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
929-
: rewriter.getAffineDimExpr(it.index());
928+
// Prefer producting identity maps whenever possible (i.e. no broadcasting
929+
// needed) because some transforms (like reshape folding)
930+
// do not support affine constant exprs.
931+
bool requiresBroadcast =
932+
(it.value() == 1 && resultType.getDimSize(it.index()) != 1);
933+
auto affineExpr = requiresBroadcast
934+
? rewriter.getAffineConstantExpr(0)
935+
: rewriter.getAffineDimExpr(it.index());
930936
affineExprs.push_back(affineExpr);
931937
}
932938
return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,26 @@ func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: t
250250

251251
// -----
252252

253+
// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
254+
// CHECK-LABEL: @test_add_1d_matching_no_broadcast
255+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
256+
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
257+
func.func @test_add_1d_matching_no_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
258+
259+
// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1xf32>
260+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_0]] : tensor<1xf32>) {
261+
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
262+
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
263+
// CHECK: linalg.yield %[[VAL_4]] : f32
264+
// CHECK: } -> tensor<1xf32>
265+
%0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
266+
267+
// CHECK: return %[[RESULT]] : tensor<1xf32>
268+
return %0 : tensor<1xf32>
269+
}
270+
271+
// -----
272+
253273
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
254274
// CHECK-LABEL: @test_add_1d_matching_static
255275
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:

0 commit comments

Comments
 (0)