
Commit 58e214b

[mlir][tosa] Enhance the conv2d verifier

This commit adds additional checks to the conv2d verifier that check error_if conditions from the TOSA specification. Notably, it adds invalid-value checking for padding, stride and dilation, output height and width checking, and bias size checking.

Change-Id: Ic5b2a459587bd781b9c8a55a912eb4b02eeb963d
Signed-off-by: Luke Hutton <[email protected]>

1 parent 2a0946b commit 58e214b

File tree

7 files changed: +216 -49 lines

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 102 additions & 7 deletions

@@ -214,6 +214,16 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Tosa utilities.
+//===----------------------------------------------------------------------===//
+
+std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
+  if (lhs % rhs != 0)
+    return std::nullopt;
+  return lhs / rhs;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -1666,13 +1676,6 @@ LogicalResult tosa::ResizeOp::verify() {
   const int64_t borderY = borderValues[0];
   const int64_t borderX = borderValues[1];
 
-  auto idivCheck = [](const int64_t lhs,
-                      const int64_t rhs) -> std::optional<int64_t> {
-    if (lhs % rhs != 0)
-      return std::nullopt;
-    return lhs / rhs;
-  };
-
   // Don't check with input height that could be broadcast (ih != 1)
   // since Linalg, a consumer of TOSA, expects broadcasting support
   // in resize to be available. Taking the cautious approach for now,
@@ -2012,6 +2015,98 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 LogicalResult Conv2DOp::verify() {
   if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
     return failure();
+
+  llvm::ArrayRef<int64_t> padding = getPad();
+  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return emitOpError("expect all padding values to be >= 0, got ") << padding;
+
+  llvm::ArrayRef<int64_t> strides = getStride();
+  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return emitOpError("expect all stride values to be >= 1, got ") << strides;
+
+  llvm::ArrayRef<int64_t> dilations = getDilation();
+  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
+    return emitOpError("expect all dilation values to be >= 1, got ")
+           << dilations;
+
+  const RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+  if (!outputType)
+    // Skip following checks if output is not ranked
+    return success();
+
+  const RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  const RankedTensorType weightType =
+      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
+
+  if (inputType && weightType) {
+    const int64_t ih = inputType.getDimSize(1);
+    const int64_t kh = weightType.getDimSize(1);
+    const int64_t oh = outputType.getDimSize(1);
+
+    const int64_t pad_top = padding[0];
+    const int64_t pad_bottom = padding[1];
+    const int64_t stride_y = strides[0];
+    const int64_t dilation_y = dilations[0];
+
+    if (ih != ShapedType::kDynamic) {
+      const std::optional<int64_t> calculatedOutHeightMinusOne = idivCheck(
+          (ih - 1) + pad_top + pad_bottom - (kh - 1) * dilation_y, stride_y);
+      if (!calculatedOutHeightMinusOne.has_value())
+        return emitOpError("expected (input_height - 1) + pad_top + pad_bottom "
+                           "- (kernel_height - 1) * dilation_y ")
+               << "to be wholly divisible by stride_y, got ((" << ih
+               << " - 1) + " << pad_top << " + " << pad_bottom << " - (" << kh
+               << " - 1) * " << dilation_y << ") / " << stride_y;
+      const int64_t calculatedOutHeight =
+          calculatedOutHeightMinusOne.value() + 1;
+      if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
+        return emitOpError("calculated output height did not match expected: ")
+               << "calculated=" << calculatedOutHeight << ", expected=" << oh;
+    }
+
+    const int64_t iw = inputType.getDimSize(2);
+    const int64_t kw = weightType.getDimSize(2);
+    const int64_t ow = outputType.getDimSize(2);
+
+    const int64_t pad_left = padding[2];
+    const int64_t pad_right = padding[3];
+    const int64_t stride_x = strides[1];
+    const int64_t dilation_x = dilations[1];
+
+    if (iw != ShapedType::kDynamic) {
+      const std::optional<int64_t> calculatedOutWidthMinusOne = idivCheck(
+          (iw - 1) + pad_left + pad_right - (kw - 1) * dilation_x, stride_x);
+      if (!calculatedOutWidthMinusOne.has_value())
+        return emitOpError("expected (input_width - 1) + pad_left + pad_right "
+                           "- (kernel_width - 1) * dilation_x ")
+               << "to be wholly divisible by stride_x, got ((" << iw
+               << " - 1) + " << pad_left << " + " << pad_right << " - (" << kw
+               << " - 1) * " << dilation_x << ") / " << stride_x;
+      const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
+      if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
+        return emitOpError("calculated output width did not match expected: ")
+               << "calculated=" << calculatedOutWidth << ", expected=" << ow;
+    }
+  }
+
+  const RankedTensorType biasType =
+      llvm::dyn_cast<RankedTensorType>(getBias().getType());
+  if (!biasType)
+    // Skip following checks if bias is not ranked
+    return success();
+
+  const int64_t bc = biasType.getDimSize(0);
+  const int64_t oc = outputType.getDimSize(3);
+  if (bc == ShapedType::kDynamic || oc == ShapedType::kDynamic)
+    // Skip following checks if bc or oc is dynamic dim
+    return success();
+
+  if (bc != oc && bc != 1)
+    return emitOpError(
+               "bias channels expected to be equal to output channels (")
+           << oc << ") or 1, got " << bc;
   return success();
 }
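For reference, the new height check encodes the output-size rule that the error message spells out: (input_height - 1) + pad_top + pad_bottom - (kernel_height - 1) * dilation_y must be exactly divisible by stride_y, and the expected output height is that quotient plus one; the width check is symmetric. For example, with input_height = 11, kernel_height = 1, zero padding, dilation_y = 1 and stride_y = 2, the expected output height is ((11 - 1) + 0 + 0 - 0) / 2 + 1 = 6.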

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 6 additions & 6 deletions

@@ -465,16 +465,16 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
   // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
-  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x49x42x28xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x49x42x28xi32>) {
   // CHECK: arith.extsi
   // CHECK: linalg.yield
-  // CHECK: } -> tensor<1x45x40x28xi32>
-  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: } -> tensor<1x49x42x28xi32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
+  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
 
   %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
+  %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x49x42x28xi32>
   return
 }
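These shape updates follow from the rule above: the kernel here is 1x1, so the dilation of [2, 1] contributes nothing ((1 - 1) * 2 = 0) and, with unit strides and no padding, the output spatial dims must equal the 49x42 input dims; the previous 45x40 result types would now be rejected by the verifier.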

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 6 additions & 6 deletions

@@ -201,23 +201,23 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 // -----
 
 // CHECK-LABEL: @conv2d_stride_2
-func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
+func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
   // CHECK: tosa.conv2d
   %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
   %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
-  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
-  return %0 : tensor<4x10x10x3xf32>
+  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x11x11x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x6x6x3xf32>
+  return %0 : tensor<4x6x6x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @conv2d_weight_2x2
-func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
+func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x9x9x1xf32> {
   // CHECK: tosa.conv2d
   %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
   %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
-  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
-  return %0 : tensor<4x10x10x1xf32>
+  %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x9x9x1xf32>
+  return %0 : tensor<4x9x9x1xf32>
 }
 
 // -----
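As with the lowering tests, the new result types are exactly what the verifier computes: @conv2d_stride_2 has an 11x11 input, 1x1 kernel and stride 2, giving ((11 - 1) + 0 + 0 - 0) / 2 + 1 = 6 per spatial dim, and @conv2d_weight_2x2 has a 10x10 input and 2x2 kernel at stride 1, giving ((10 - 1) + 0 + 0 - (2 - 1) * 1) / 1 + 1 = 9.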

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 72 additions & 0 deletions

@@ -1197,3 +1197,75 @@ func.func @broadcast_resize_bilinear_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x
 
   return %resize : tensor<3x4x5x7xi32>
 }
+
+// -----
+
+func.func @test_conv2d_invalid_padding(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expect all padding values to be >= 0, got 0, 0, -1, 0}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, -1, 0>, stride = array<i64: 1, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_invalid_stride(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expect all stride values to be >= 1, got 0, 1}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 0, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_invalid_dilation(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expect all dilation values to be >= 1, got 1, 0}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_wholly_divisible_height(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expected (input_height - 1) + pad_top + pad_bottom - (kernel_height - 1) * dilation_y to be wholly divisible by stride_y, got ((4 - 1) + 0 + 0 - (1 - 1) * 1) / 2}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_wholly_divisible_width(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op expected (input_width - 1) + pad_left + pad_right - (kernel_width - 1) * dilation_x to be wholly divisible by stride_x, got ((4 - 1) + 0 + 0 - (1 - 1) * 1) / 2}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 2>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_unexpected_output_height(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x6x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op calculated output height did not match expected: calculated=4, expected=6}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x4x8xf32>
+  return %0 : tensor<1x6x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_unexpected_output_width(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x6x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op calculated output width did not match expected: calculated=4, expected=6}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x6x8xf32>
+  return %0 : tensor<1x4x6x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<7xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op bias channels expected to be equal to output channels (8) or 1, got 7}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
+         : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
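For contrast, here is a minimal positive sketch (not part of this patch; the function name is illustrative) using the same operand convention as the tests above. With a 5x5 input, 1x1 kernel, zero padding and stride 2, the verifier expects ((5 - 1) + 0 + 0 - (1 - 1) * 1) / 2 + 1 = 3 per spatial dim, so a tensor<1x3x3x8xf32> result satisfies all of the new checks:

func.func @test_conv2d_valid_stride(%arg0: tensor<1x5x5x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x3x3x8xf32> {
  // Output height/width: ((5 - 1) + 0 + 0 - (1 - 1) * 1) / 2 + 1 = 3; bias (8) matches output channels (8).
  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, local_bound = true}
         : (tensor<1x5x5x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3x3x8xf32>
  return %0 : tensor<1x3x3x8xf32>
}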
