diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c3d9d2a773ae7..c3dd3d00e7b8e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1008,6 +1008,20 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
     }
   }
 
+  // Fold cast from bf16 -> f32 -> bf16 into no-op.
+  // The outer cast is rewired in place to read the inner cast's input;
+  // returning this op's own result reports that in-place update.
+  if (auto cast = getInput().getDefiningOp<CastOp>()) {
+    auto sourceElTy = cast.getInput().getType().getElementType();
+    auto intermediateElTy = cast.getType().getElementType();
+    auto finalElTy = getType().getElementType();
+    if (isa<BFloat16Type>(sourceElTy) && isa<Float32Type>(intermediateElTy) &&
+        isa<BFloat16Type>(finalElTy)) {
+      getInputMutable().assign(cast.getInput());
+      return getResult();
+    }
+  }
+
   auto operand = llvm::dyn_cast_if_present(adaptor.getInput());
   if (!operand)
     return {};
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 2035659f17146..f35df639cca52 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -55,6 +55,14 @@ func.func @cast_fold_double(%arg0: tensor) -> tensor {
   return %1 : tensor
 }
 
+// CHECK-LABEL: @cast_fold_double2
+func.func @cast_fold_double2(%arg0: tensor) -> tensor {
+  // CHECK: return %arg0
+  %0 = tosa.cast %arg0 : (tensor) -> tensor
+  %1 = tosa.cast %0 : (tensor) -> tensor
+  return %1 : tensor
+}
+
 // CHECK-LABEL: @cast_no_fold_double1
 func.func @cast_no_fold_double1(%arg0: tensor) -> tensor {
   // CHECK: tosa.cast{{.*}} (tensor) -> tensor