File tree Expand file tree Collapse file tree 2 files changed +20
-0
lines changed
Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -1008,6 +1008,18 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
10081008 }
10091009 }
10101010
1011+ // Fold cast from bf16 -> f32 -> bf16 into no-op.
1012+ if (auto cast = getInput ().getDefiningOp <CastOp>()) {
1013+ auto sourceElTy = cast.getInput ().getType ().getElementType ();
1014+ auto intermediateElTy = cast.getType ().getElementType ();
1015+ auto finalElTy = getType ().getElementType ();
1016+ if (isa<BFloat16Type>(sourceElTy) && isa<Float32Type>(intermediateElTy) &&
1017+ isa<BFloat16Type>(finalElTy)) {
1018+ getInputMutable ().assign (cast.getInput ());
1019+ return getResult ();
1020+ }
1021+ }
1022+
10111023 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput ());
10121024 if (!operand)
10131025 return {};
Original file line number Diff line number Diff line change @@ -55,6 +55,14 @@ func.func @cast_fold_double(%arg0: tensor<?x1xf32>) -> tensor<?x1xi8> {
5555 return %1 : tensor <?x1 xi8 >
5656}
5757
58+ // CHECK-LABEL: @cast_fold_double
59+ func.func @cast_fold_double2 (%arg0: tensor <?x1 xbf16 >) -> tensor <?x1 xbf16 > {
60+ // CHECK: return %arg0
61+ %0 = tosa.cast %arg0 : (tensor <?x1 xbf16 >) -> tensor <?x1 xf32 >
62+ %1 = tosa.cast %0 : (tensor <?x1 xf32 >) -> tensor <?x1 xbf16 >
63+ return %1 : tensor <?x1 xbf16 >
64+ }
65+
5866// CHECK-LABEL: @cast_no_fold_double1
5967func.func @cast_no_fold_double1 (%arg0: tensor <?x1 xf32 >) -> tensor <?x1 xi8 > {
6068 // CHECK: tosa.cast{{.*}} (tensor<?x1xf32>) -> tensor<?x1xui16>
You can’t perform that action at this time.
0 commit comments