Skip to content

Commit 0ba2000

Browse files
authored
[mlir][tosa] Enhance the conv2d verifier (#128693)
This commit adds additional checks to the conv2d verifier that check ERROR_IF conditions from the TOSA specification. Notably, it adds checking of invalid padding, stride, and dilation values; output height and width checking; and bias size checking. Signed-off-by: Luke Hutton <[email protected]>
1 parent a98c294 commit 0ba2000

File tree

7 files changed

+215
-49
lines changed

7 files changed

+215
-49
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,16 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
214214
}
215215
}
216216

217+
//===----------------------------------------------------------------------===//
218+
// Tosa utilities.
219+
//===----------------------------------------------------------------------===//
220+
221+
/// Returns lhs / rhs when the division is exact (remainder zero),
/// std::nullopt otherwise. Shared by TOSA verifiers to check the
/// "wholly divisible" ERROR_IF conditions from the TOSA specification.
///
/// A zero divisor yields std::nullopt rather than undefined behavior;
/// callers such as Conv2DOp::verify already reject stride < 1, but this
/// guard keeps the helper safe for any future caller.
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (rhs == 0 || lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
226+
217227
//===----------------------------------------------------------------------===//
218228
// TOSA Operator Verifiers.
219229
//===----------------------------------------------------------------------===//
@@ -1621,13 +1631,6 @@ LogicalResult tosa::ResizeOp::verify() {
16211631
const int64_t borderY = borderValues[0];
16221632
const int64_t borderX = borderValues[1];
16231633

1624-
auto idivCheck = [](const int64_t lhs,
1625-
const int64_t rhs) -> std::optional<int64_t> {
1626-
if (lhs % rhs != 0)
1627-
return std::nullopt;
1628-
return lhs / rhs;
1629-
};
1630-
16311634
// Don't check with input height that could be broadcast (ih != 1)
16321635
// since Linalg, a consumer of TOSA, expects broadcasting support
16331636
// in resize to be available. Taking the cautious approach for now,
@@ -1967,6 +1970,97 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
19671970
LogicalResult Conv2DOp::verify() {
  // Generic convolution operand checks and accumulator-type/mode checks.
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  // ERROR_IF: every pad value must be non-negative.
  const llvm::ArrayRef<int64_t> padVals = getPad();
  if (llvm::any_of(padVals, [](const int64_t v) { return v < 0; }))
    return emitOpError("expect all padding values to be >= 0, got ") << padVals;

  // ERROR_IF: every stride must be at least 1.
  const llvm::ArrayRef<int64_t> strideVals = getStride();
  if (llvm::any_of(strideVals, [](const int64_t v) { return v < 1; }))
    return emitOpError("expect all stride values to be >= 1, got ")
           << strideVals;

  // ERROR_IF: every dilation must be at least 1.
  const llvm::ArrayRef<int64_t> dilationVals = getDilation();
  if (llvm::any_of(dilationVals, [](const int64_t v) { return v < 1; }))
    return emitOpError("expect all dilation values to be >= 1, got ")
           << dilationVals;

  // The shape-based checks below require a ranked output type.
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (inputType && weightType) {
    // Verifies one spatial dimension: the padded, dilated input extent must
    // divide evenly by the stride, and the resulting size must agree with the
    // declared (static) output size.
    const auto checkSpatialDim =
        [this](const int64_t inSize, const int64_t kSize, const int64_t outSize,
               const int64_t padBefore, const int64_t padAfter,
               const int64_t stride, const int64_t dilation,
               const llvm::StringRef dimName, const llvm::StringRef dimAxis,
               const llvm::StringRef padBeforeName,
               const llvm::StringRef padAfterName) -> LogicalResult {
      // Nothing can be verified against dynamic input/kernel extents.
      if (inSize == ShapedType::kDynamic || kSize == ShapedType::kDynamic)
        return success();

      const std::optional<int64_t> outMinusOne = idivCheck(
          inSize - 1 + padBefore + padAfter - (kSize - 1) * dilation, stride);
      if (!outMinusOne.has_value())
        return emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kSize << " - 1) * " << dilation << ") / " << stride;

      const int64_t expectedOut = outMinusOne.value() + 1;
      if (outSize != ShapedType::kDynamic && expectedOut != outSize)
        return emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << expectedOut << ", expected=" << outSize;

      return success();
    };

    // NHWC layout: dimension 1 is height, dimension 2 is width.
    if (failed(checkSpatialDim(inputType.getDimSize(1),
                               weightType.getDimSize(1),
                               outputType.getDimSize(1), padVals[0], padVals[1],
                               strideVals[0], dilationVals[0], "height", "y",
                               "top", "bottom")))
      return failure();

    if (failed(checkSpatialDim(inputType.getDimSize(2),
                               weightType.getDimSize(2),
                               outputType.getDimSize(2), padVals[2], padVals[3],
                               strideVals[1], dilationVals[1], "width", "x",
                               "left", "right")))
      return failure();
  }

  // The bias check requires a ranked bias type.
  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels = outputType.getDimSize(3);
  // Dynamic extents cannot be compared; skip.
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    return success();

  // ERROR_IF: bias must either match the output channels or be
  // broadcastable (size 1).
  if (biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
  return success();
}
19722066

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,16 +464,16 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
464464
// CHECK-LABEL: @conv2d_i8
465465
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
466466
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
467-
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
468-
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
467+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x49x42x28xi32>
468+
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x49x42x28xi32>) {
469469
// CHECK: arith.extsi
470470
// CHECK: linalg.yield
471-
// CHECK: } -> tensor<1x45x40x28xi32>
472-
// CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
473-
// HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
471+
// CHECK: } -> tensor<1x49x42x28xi32>
472+
// CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
473+
// HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x49x42x28xi32>) -> tensor<1x49x42x28xi32>
474474

475475
%zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
476-
%0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
476+
%0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x49x42x28xi32>
477477
return
478478
}
479479

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,23 +201,23 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
201201
// -----
202202

203203
// CHECK-LABEL: @conv2d_stride_2
204-
func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
204+
func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
205205
// CHECK: tosa.conv2d
206206
%weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
207207
%bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
208-
%0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
209-
return %0 : tensor<4x10x10x3xf32>
208+
%0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1>} : (tensor<4x11x11x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x6x6x3xf32>
209+
return %0 : tensor<4x6x6x3xf32>
210210
}
211211

212212
// -----
213213

214214
// CHECK-LABEL: @conv2d_weight_2x2
215-
func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
215+
func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x9x9x1xf32> {
216216
// CHECK: tosa.conv2d
217217
%weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
218218
%bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
219-
%0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
220-
return %0 : tensor<4x10x10x1xf32>
219+
%0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x9x9x1xf32>
220+
return %0 : tensor<4x9x9x1xf32>
221221
}
222222

223223
// -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,3 +1171,75 @@ func.func @broadcast_resize_bilinear_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x
11711171

11721172
return %resize : tensor<3x4x5x7xi32>
11731173
}
1174+
1175+
// -----
1176+
1177+
func.func @test_conv2d_invalid_padding(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
1178+
// expected-error@+1 {{'tosa.conv2d' op expect all padding values to be >= 0, got 0, 0, -1, 0}}
1179+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, -1, 0>, stride = array<i64: 1, 1>, local_bound = true}
1180+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
1181+
return %0 : tensor<1x4x4x8xf32>
1182+
}
1183+
1184+
// -----
1185+
1186+
func.func @test_conv2d_invalid_stride(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
1187+
// expected-error@+1 {{'tosa.conv2d' op expect all stride values to be >= 1, got 0, 1}}
1188+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 0, 1>, local_bound = true}
1189+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
1190+
return %0 : tensor<1x4x4x8xf32>
1191+
}
1192+
1193+
// -----
1194+
1195+
func.func @test_conv2d_invalid_dilation(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
1196+
// expected-error@+1 {{'tosa.conv2d' op expect all dilation values to be >= 1, got 1, 0}}
1197+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
1198+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
1199+
return %0 : tensor<1x4x4x8xf32>
1200+
}
1201+
1202+
// -----
1203+
1204+
func.func @test_conv2d_wholly_divisible_height(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
1205+
// expected-error@+1 {{'tosa.conv2d' op expected input_height - 1 + pad_top + pad_bottom - (kernel_height - 1) * dilation_y to be wholly divisible by stride_y, got (4 - 1 + 0 + 0 - (1 - 1) * 1) / 2}}
1206+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 1>, local_bound = true}
1207+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
1208+
return %0 : tensor<1x4x4x8xf32>
1209+
}
1210+
1211+
// -----
1212+
1213+
func.func @test_conv2d_wholly_divisible_width(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
1214+
// expected-error@+1 {{'tosa.conv2d' op expected input_width - 1 + pad_left + pad_right - (kernel_width - 1) * dilation_x to be wholly divisible by stride_x, got (4 - 1 + 0 + 0 - (1 - 1) * 1) / 2}}
1215+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 2>, local_bound = true}
1216+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
1217+
return %0 : tensor<1x4x4x8xf32>
1218+
}
1219+
1220+
// -----
1221+
1222+
func.func @test_conv2d_unexpected_output_height(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x6x4x8xf32> {
1223+
// expected-error@+1 {{'tosa.conv2d' op calculated output height did not match expected: calculated=4, expected=6}}
1224+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
1225+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x4x8xf32>
1226+
return %0 : tensor<1x6x4x8xf32>
1227+
}
1228+
1229+
// -----
1230+
1231+
func.func @test_conv2d_unexpected_output_width(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x6x8xf32> {
1232+
// expected-error@+1 {{'tosa.conv2d' op calculated output width did not match expected: calculated=4, expected=6}}
1233+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
1234+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x6x8xf32>
1235+
return %0 : tensor<1x4x6x8xf32>
1236+
}
1237+
1238+
// -----
1239+
1240+
func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<7xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
1241+
// expected-error@+1 {{'tosa.conv2d' op bias channels expected to be equal to output channels (8) or 1, got 7}}
1242+
%0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true}
1243+
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
1244+
return %0 : tensor<1x4x4x8xf32>
1245+
}

0 commit comments

Comments
 (0)