Skip to content

Commit 729187c

Browse files
committed
Restrict to bf16-f32-bf16
1 parent ff2485b commit 729187c

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,11 +1008,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
10081008
}
10091009
}
10101010

1011-
// cast-to-bf16(cast-to-f32(x)) -> cast-to-bf16(x)
1011+
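// Fold a bf16 -> f32 -> bf16 cast round-trip into a no-op: every bf16 value
// is exactly representable in f32, so narrowing back yields the original.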
10121012
if (auto cast = getInput().getDefiningOp<CastOp>()) {
1013+
auto sourceElTy = cast.getInput().getType().getElementType();
10131014
auto intermediateElTy = cast.getType().getElementType();
10141015
auto finalElTy = getType().getElementType();
1015-
if (isa<Float32Type>(intermediateElTy) && isa<BFloat16Type>(finalElTy)) {
1016+
if (isa<BFloat16Type>(sourceElTy) && isa<Float32Type>(intermediateElTy) &&
1017+
isa<BFloat16Type>(finalElTy)) {
10161018
getInputMutable().assign(cast.getInput());
10171019
return getResult();
10181020
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ func.func @cast_fold_double(%arg0: tensor<?x1xf32>) -> tensor<?x1xi8> {
5656
}
5757

5858
// CHECK-LABEL: @cast_fold_double2
59-
func.func @cast_fold_double2(%arg0: tensor<?x1xf16>) -> tensor<?x1xbf16> {
60-
// CHECK: tosa.cast{{.*}} (tensor<?x1xf16>) -> tensor<?x1xbf16>
61-
%0 = tosa.cast %arg0 : (tensor<?x1xf16>) -> tensor<?x1xf32>
59+
func.func @cast_fold_double2(%arg0: tensor<?x1xbf16>) -> tensor<?x1xbf16> {
60+
// CHECK: return %arg0
61+
%0 = tosa.cast %arg0 : (tensor<?x1xbf16>) -> tensor<?x1xf32>
6262
%1 = tosa.cast %0 : (tensor<?x1xf32>) -> tensor<?x1xbf16>
6363
return %1 : tensor<?x1xbf16>
6464
}

0 commit comments

Comments
 (0)