Skip to content

Commit 8eff95a

Browse files
authored
Merge pull request #413 from Xilinx/matthias.fold_cast_float
TOSA: fold cast-to-bf16(cast-to-f32(x)) -> cast-to-bf16(x)
2 parents 5d0d0ff + 729187c commit 8eff95a

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,18 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
10081008
}
10091009
}
10101010

1011+
// Fold cast bf16 -> f32 -> bf16: feed the original bf16 value straight into this cast, bypassing the f32 round trip.
1012+
if (auto cast = getInput().getDefiningOp<CastOp>()) {
1013+
auto sourceElTy = cast.getInput().getType().getElementType();
1014+
auto intermediateElTy = cast.getType().getElementType();
1015+
auto finalElTy = getType().getElementType();
1016+
if (isa<BFloat16Type>(sourceElTy) && isa<Float32Type>(intermediateElTy) &&
1017+
isa<BFloat16Type>(finalElTy)) {
1018+
getInputMutable().assign(cast.getInput());
1019+
return getResult();
1020+
}
1021+
}
1022+
10111023
auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
10121024
if (!operand)
10131025
return {};

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ func.func @cast_fold_double(%arg0: tensor<?x1xf32>) -> tensor<?x1xi8> {
5555
return %1 : tensor<?x1xi8>
5656
}
5757

58+
// CHECK-LABEL: @cast_fold_double2
59+
func.func @cast_fold_double2(%arg0: tensor<?x1xbf16>) -> tensor<?x1xbf16> {
60+
// CHECK: return %arg0
61+
%0 = tosa.cast %arg0 : (tensor<?x1xbf16>) -> tensor<?x1xf32>
62+
%1 = tosa.cast %0 : (tensor<?x1xf32>) -> tensor<?x1xbf16>
63+
return %1 : tensor<?x1xbf16>
64+
}
65+
5866
// CHECK-LABEL: @cast_no_fold_double1
5967
func.func @cast_no_fold_double1(%arg0: tensor<?x1xf32>) -> tensor<?x1xi8> {
6068
// CHECK: tosa.cast{{.*}} (tensor<?x1xf32>) -> tensor<?x1xui16>

0 commit comments

Comments
 (0)