diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 0a6f2477560a1..1955eec9964eb 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1827,8 +1827,8 @@ class GenericResizeConverter : public OpRewritePattern { auto resultTy = cast(op.getType()); auto resultETy = resultTy.getElementType(); - bool floatingPointMode = resultETy.isF16() || resultETy.isF32(); - auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type(); + bool floatingPointMode = isa(resultETy); + auto floatTy = resultETy; auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir index ff2cbbc0b3938..6998aee45b887 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir @@ -12,6 +12,18 @@ func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x // ----- +// CHECK-LABEL: @unary_resize_nearest_bf16 +func.func @unary_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> { + %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> + %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2> + %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2> + %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16> + // CHECK: return %arg0 + return %resize : tensor<3x1x1x7xbf16> +} + +// ----- + // CHECK-LABEL: @unary_resize_nearest_fp16 func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> { %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> @@ -36,6 +48,18 @@ func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1 // ----- +// CHECK-LABEL: @unary_resize_bilinear_bf16 +func.func @unary_resize_bilinear_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> { + %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> + %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2> + %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2> + %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16> + // CHECK: return %arg0 + return %resize : tensor<3x1x1x7xbf16> +} + +// ----- + // CHECK-LABEL: @unary_resize_bilinear_fp16 func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> { %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> @@ -60,6 +84,26 @@ func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7 // ----- +// CHECK-LABEL: @broadcast_resize_nearest_bf16 +func.func @broadcast_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x5x7xbf16> { + // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 + // CHECK-NEXT{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xbf16> into tensor<3x7xbf16> + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x1x5x7xbf16> + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK-SAME: indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xbf16>) outs(%[[EMPTY]] : tensor<3x1x5x7xbf16>) + // CHECK: ^bb0(%[[IN:.+]]: bf16, %[[OUT:.+]]: bf16): + // CHECK: linalg.yield %[[IN]] : bf16 + %scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> + %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2> + %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2> + %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xbf16> + + return %resize : tensor<3x1x5x7xbf16> +} + +// ----- + // CHECK-LABEL: @broadcast_resize_nearest_f32 func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32> { // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0