File tree Expand file tree Collapse file tree 2 files changed +18
-0
lines changed
Expand file tree Collapse file tree 2 files changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -1008,6 +1008,16 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
10081008 }
10091009 }
10101010
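1011+ // Fold cast-to-bf16(cast-to-f32(x)) -> cast-to-bf16(x): when the input is itself a cast producing f32 and this cast produces bf16, read directly from the inner cast's input.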
1012+ if (auto cast = getInput ().getDefiningOp <CastOp>()) {
1013+ auto intermediateElTy = cast.getType ().getElementType ();
1014+ auto finalElTy = getType ().getElementType ();
1015+ if (isa<Float32Type>(intermediateElTy) && isa<BFloat16Type>(finalElTy)) {
1016+ getInputMutable ().assign (cast.getInput ());
1017+ return getResult ();
1018+ }
1019+ }
1020+
10111021 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput ());
10121022 if (!operand)
10131023 return {};
Original file line number Diff line number Diff line change @@ -55,6 +55,14 @@ func.func @cast_fold_double(%arg0: tensor<?x1xf32>) -> tensor<?x1xi8> {
5555 return %1 : tensor <?x1 xi8 >
5656}
5757
58+ // CHECK-LABEL: @cast_fold_double2
59+ func.func @cast_fold_double2 (%arg0: tensor <?x1 xf16 >) -> tensor <?x1 xbf16 > {
60+ // CHECK: tosa.cast{{.*}} (tensor<?x1xf16>) -> tensor<?x1xbf16>
61+ %0 = tosa.cast %arg0 : (tensor <?x1 xf16 >) -> tensor <?x1 xf32 >
62+ %1 = tosa.cast %0 : (tensor <?x1 xf32 >) -> tensor <?x1 xbf16 >
63+ return %1 : tensor <?x1 xbf16 >
64+ }
65+
5866// CHECK-LABEL: @cast_no_fold_double1
5967func.func @cast_no_fold_double1 (%arg0: tensor <?x1 xf32 >) -> tensor <?x1 xi8 > {
6068 // CHECK: tosa.cast{{.*}} (tensor<?x1xf32>) -> tensor<?x1xui16>
You can’t perform that action at this time.
0 commit comments