Skip to content

Commit 2191f5a

Browse files
authored
[MLIR][TOSA] Add missing SameOperandsAndResultShape Trait to tosa.cast (#153826)
According to the TOSA spec, tosa.cast is only changing the elementtype, and not the shape of the input tensor Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 958cec0 commit 2191f5a

File tree

3 files changed

+16
-9
lines changed

3 files changed

+16
-9
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
22472247
//===----------------------------------------------------------------------===//
22482248
// Operator: cast
22492249
//===----------------------------------------------------------------------===//
2250-
def Tosa_CastOp: Tosa_Op<"cast", [Pure,
2250+
def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
22512251
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
22522252
["inferReturnTypeComponents"]>]> {
22532253

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,21 +1304,19 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
13041304
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
13051305
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
13061306
// CHECK: }
1307-
%0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
1308-
%1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
1309-
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
1310-
return %2 : tensor<3x600x1200xf32>
1307+
%0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
1308+
%1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
1309+
return %1 : tensor<3x600x1200xf32>
13111310
}
13121311

13131312
// -----
13141313

13151314
// CHECK-LABEL: @do_not_fold_reciprocal_int
13161315
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
13171316
// CHECK: tosa.reciprocal
1318-
%0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
1319-
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
1320-
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
1321-
return %2 : tensor<3x600x1200xi32>
1317+
%0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
1318+
%1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
1319+
return %1 : tensor<3x600x1200xi32>
13221320
}
13231321

13241322
// -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66

77
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
88

9+
10+
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
11+
// expected-error@+1{{'tosa.cast' op requires the same shape for all operands and results}}
12+
%1 = "tosa.cast"(%arg0) : (tensor<i1>) -> tensor<5xi32>
13+
return %1 : tensor<5xi32>
14+
}
15+
16+
// -----
17+
918
func.func @test_const() -> tensor<1xf32> {
1019
// expected-error@+1{{'tosa.const' op expected same attr/result element types}}
1120
%0 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>

0 commit comments

Comments
 (0)