Skip to content

Commit 9b2e29a

Browse files
committed
[TOSA] Add SameOperandsAndResultRank to TOSA Ops
This patch adds the SameOperandsAndResultRank trait to TOSA operators that have the ResultsBroadcastableShape trait. The SameOperandsAndResultRank trait requires that all operands and results have matching ranks unless the operand/result is unranked. This also renders the TosaMakeBroadcastable pass unnecessary — but the pass is left in for now in case it is still used in some flows. The lit test, broadcast.mlir, is removed. This also adds verification of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes. Signed-off-by: Tai Ly <[email protected]> Change-Id: I27bf16b31f15aa92d42ad5376b8791cf74e4f6ac
1 parent d6315af commit 9b2e29a

File tree

7 files changed

+104
-372
lines changed

7 files changed

+104
-372
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
231231
//===----------------------------------------------------------------------===//
232232

233233
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
234-
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
234+
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
235235
TosaResolvableShapeOperands])> {
236236
}
237237

@@ -241,6 +241,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
241241
["inferReturnTypeComponents"]>,
242242
ResultsBroadcastableShape,
243243
TosaElementwiseOperator,
244+
SameOperandsAndResultRank,
244245
Pure])> {
245246
let assemblyFormat =
246247
"operands attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
303303
}
304304
}
305305

306+
/// recursively validate tosa ops with SameOperandsAndResultRank trait in region
307+
/// and all nested regions
308+
void validateSameOperandsAndResultRankTrait(Region &region) {
309+
int errs = 0;
310+
for (auto &block : region) {
311+
for (auto &op : block) {
312+
if (!op.getDialect() ||
313+
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
314+
continue;
315+
if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
316+
if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
317+
errs++;
318+
}
319+
}
320+
WhileOp whileOp = dyn_cast<WhileOp>(op);
321+
IfOp ifOp = dyn_cast<IfOp>(op);
322+
if (whileOp || ifOp) {
323+
// recurse into whileOp's regions
324+
for (auto &next : op.getRegions()) {
325+
validateSameOperandsAndResultRankTrait(next);
326+
}
327+
}
328+
}
329+
}
330+
}
331+
306332
/// Pass that performs shape propagation across TOSA operations. This includes
307333
/// migrating to within the regions of if/while operations.
308334
struct TosaInferShapes
@@ -313,6 +339,8 @@ struct TosaInferShapes
313339
TypeModificationState state;
314340
propagateShapesInRegion(func.getBody(), state);
315341
state.commit();
342+
343+
validateSameOperandsAndResultRankTrait(func.getBody());
316344
}
317345
};
318346
} // namespace

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
9494
// CHECK: } -> tensor<f32>
9595
%0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
9696

97+
9798
// CHECK: return [[RESULT]] : tensor<f32>
9899
return %0 : tensor<f32>
99100
}
@@ -104,20 +105,20 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
104105
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
105106
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
106107

107-
// CHECK-LABEL: func.func @test_add_0d_broadcast(
108+
// CHECK-LABEL: func.func @test_add_2d_broadcast(
108109
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
109-
// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
110-
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
110+
// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> {
111111
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
112-
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
112+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
113113
// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
114114
// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
115115
// CHECK: linalg.yield %[[ADD]] : f32
116116
// CHECK: } -> tensor<2x1xf32>
117117
// CHECK: return %[[RESULT]] : tensor<2x1xf32>
118118
// CHECK: }
119-
func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
120-
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
119+
func.func @test_add_2d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x1xf32> {
120+
// tosa element-wise operators now require operands of equal ranks
121+
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
121122
return %0 : tensor<2x1xf32>
122123
}
123124

@@ -364,23 +365,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
364365

365366
// -----
366367

367-
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
368-
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
369-
// CHECK-LABEL: @test_add_2d_different_ranks
370-
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
371-
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
372368
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
373-
374-
// CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
375-
// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
376-
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
377-
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
378-
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
379-
// CHECK: linalg.yield %[[VAL_4]] : f32
380-
// CHECK: } -> tensor<2x3x4xf32>
381-
%0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
382-
383-
// CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
369+
// expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
370+
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
384371
return %0 : tensor<2x3x4xf32>
385372
}
386373

0 commit comments

Comments
 (0)