Skip to content

Commit e3c5c4d

Browse files
authored
Merge pull request #612 from Xilinx/jrickert.tosa_same_shape
Add missing SameOperandsAndResultShape Trait to tosa.cast
2 parents 20ae70c + f59f054 commit e3c5c4d

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1856,7 +1856,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18561856
//===----------------------------------------------------------------------===//
18571857
// Operator: cast
18581858
//===----------------------------------------------------------------------===//
1859-
def Tosa_CastOp: Tosa_Op<"cast", [Pure,
1859+
def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
18601860
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
18611861
["inferReturnTypeComponents"]>]> {
18621862

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,21 +1332,19 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
13321332
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
13331333
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
13341334
// CHECK: }
1335-
%0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
1336-
%1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
1337-
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
1338-
return %2 : tensor<3x600x1200xf32>
1335+
%0 = "tosa.const"(){ value = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
1336+
%1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
1337+
return %1 : tensor<3x600x1200xf32>
13391338
}
13401339

13411340
// -----
13421341

13431342
// CHECK-LABEL: @do_not_fold_reciprocal_int
13441343
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
13451344
// CHECK: tosa.reciprocal
1346-
%0 = "tosa.const"(){ value = dense<11>: tensor<i32> }: () -> tensor<i32>
1347-
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
1348-
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
1349-
return %2 : tensor<3x600x1200xi32>
1345+
%0 = "tosa.const"(){ value = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
1346+
%1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
1347+
return %1 : tensor<3x600x1200xi32>
13501348
}
13511349

13521350
// -----

0 commit comments

Comments
 (0)