Skip to content

Commit b0de4e6

Browse files
authored
[mlir][tosa] Add support for BF16 in tosa.resize legalization (#158616)
Extend the resize linalg legalization functionality with BF16 support and in accordance to the TOSA specification. Signed-off-by: Georgios Pinitas <[email protected]>
1 parent 29b6433 commit b0de4e6

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,8 +1827,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
18271827
auto resultTy = cast<ShapedType>(op.getType());
18281828
auto resultETy = resultTy.getElementType();
18291829

1830-
bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1831-
auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1830+
bool floatingPointMode = isa<FloatType>(resultETy);
1831+
auto floatTy = resultETy;
18321832

18331833
auto imageH = inputTy.getShape()[1];
18341834
auto imageW = inputTy.getShape()[2];

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x
1212

1313
// -----
1414

15+
// CHECK-LABEL: @unary_resize_nearest_bf16
16+
func.func @unary_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
17+
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
18+
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
19+
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
20+
%resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
21+
// CHECK: return %arg0
22+
return %resize : tensor<3x1x1x7xbf16>
23+
}
24+
25+
// -----
26+
1527
// CHECK-LABEL: @unary_resize_nearest_fp16
1628
func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
1729
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -36,6 +48,18 @@ func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1
3648

3749
// -----
3850

51+
// CHECK-LABEL: @unary_resize_bilinear_bf16
52+
func.func @unary_resize_bilinear_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
53+
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
54+
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
55+
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
56+
%resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
57+
// CHECK: return %arg0
58+
return %resize : tensor<3x1x1x7xbf16>
59+
}
60+
61+
// -----
62+
3963
// CHECK-LABEL: @unary_resize_bilinear_fp16
4064
func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
4165
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -60,6 +84,26 @@ func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7
6084

6185
// -----
6286

87+
// CHECK-LABEL: @broadcast_resize_nearest_bf16
88+
func.func @broadcast_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x5x7xbf16> {
89+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
90+
// CHECK-NEXT{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xbf16> into tensor<3x7xbf16>
91+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x1x5x7xbf16>
92+
// CHECK: %[[GENERIC:.+]] = linalg.generic
93+
// CHECK-SAME: indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
94+
// CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xbf16>) outs(%[[EMPTY]] : tensor<3x1x5x7xbf16>)
95+
// CHECK: ^bb0(%[[IN:.+]]: bf16, %[[OUT:.+]]: bf16):
96+
// CHECK: linalg.yield %[[IN]] : bf16
97+
%scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
98+
%offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
99+
%border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
100+
%resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xbf16>
101+
102+
return %resize : tensor<3x1x5x7xbf16>
103+
}
104+
105+
// -----
106+
63107
// CHECK-LABEL: @broadcast_resize_nearest_f32
64108
func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32> {
65109
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0

0 commit comments

Comments
 (0)