File tree Expand file tree Collapse file tree 2 files changed +7
-5
lines changed
Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -1008,11 +1008,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
10081008 }
10091009 }
10101010
1011- // cast-to- bf16(cast-to- f32(x)) -> cast-to- bf16(x)
1011+ // Fold cast from bf16 -> f32 -> bf16 into no-op.
10121012 if (auto cast = getInput ().getDefiningOp <CastOp>()) {
1013+ auto sourceElTy = cast.getInput ().getType ().getElementType ();
10131014 auto intermediateElTy = cast.getType ().getElementType ();
10141015 auto finalElTy = getType ().getElementType ();
1015- if (isa<Float32Type>(intermediateElTy) && isa<BFloat16Type>(finalElTy)) {
1016+ if (isa<BFloat16Type>(sourceElTy) && isa<Float32Type>(intermediateElTy) &&
1017+ isa<BFloat16Type>(finalElTy)) {
10161018 getInputMutable ().assign (cast.getInput ());
10171019 return getResult ();
10181020 }
Original file line number Diff line number Diff line change @@ -56,9 +56,9 @@ func.func @cast_fold_double(%arg0: tensor<?x1xf32>) -> tensor<?x1xi8> {
5656}
5757
5858// CHECK-LABEL: @cast_fold_double
59- func.func @cast_fold_double2 (%arg0: tensor <?x 1 x f16 >) -> tensor <?x1 xbf16 > {
60- // CHECK: tosa.cast{{.*}} (tensor<?x1xf16>) -> tensor<?x1xbf16>
61- %0 = tosa.cast %arg0 : (tensor <?x 1 x f16 >) -> tensor <?x1 xf32 >
59+ func.func @cast_fold_double2 (%arg0: tensor <?x 1 x bf16 >) -> tensor <?x1 xbf16 > {
60+ // CHECK: return %arg0
61+ %0 = tosa.cast %arg0 : (tensor <?x 1 x bf16 >) -> tensor <?x1 xf32 >
6262 %1 = tosa.cast %0 : (tensor <?x1 xf32 >) -> tensor <?x1 xbf16 >
6363 return %1 : tensor <?x1 xbf16 >
6464}
You can’t perform that action at this time.
0 commit comments