diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c94edad62cac7..cc78aaed911e6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -125,7 +125,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {

   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -172,7 +172,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {

   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor5D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -218,7 +218,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {

   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -434,7 +434,7 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {

   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 67011f22fbe2a..b9ac1ff705514 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -84,11 +84,6 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                "number">;

-// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
-def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
-
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 107e09a5e7ab4..69b1627583234 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,

 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
-  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-  if (!inputType) {
-    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
-    return failure();
-  }
-
-  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
-  if (!weightType) {
-    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
-    return failure();
-  }
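+  // The ODS operand constraints already guarantee tensor-typed input and
+  // weight (ranked or unranked), so these casts cannot fail.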
+  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
+  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
@@ -3063,14 +3052,6 @@ LogicalResult TransposeConv2DOp::verify() {
     return emitOpError("expect all stride values to be >= 1, got [")
            << strides << "]";

-  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  const auto outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-
-  const auto weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
   const auto checkPadAgainstKernelDim =
       [this](int64_t pad_value, int64_t kernel_dim_size,
              llvm::StringRef pad_name,
@@ -3084,69 +3065,77 @@ LogicalResult TransposeConv2DOp::verify() {
   };

   const llvm::ArrayRef<int64_t> padding = getOutPad();
-
   const int64_t outPadTop = padding[0];
   const int64_t outPadBottom = padding[1];
+  const int64_t outPadLeft = padding[2];
+  const int64_t outPadRight = padding[3];

-  const int64_t kernelHeight = weightType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(kernelHeight)) {
-    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
-                                        "KH")))
-      return failure();
-
-    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
-                                        "out_pad_bottom", "KH")))
-      return failure();
-  }
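+  // The padding-vs-kernel checks below need KH/KW, so they are skipped when
+  // the weight type is unranked.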
" - << kernelHeight; - } + const auto inputType = llvm::dyn_cast(getInput().getType()); + if (inputType && weightType) { + const int64_t inputHeight = inputType.getDimSize(1); + const int64_t kernelHeight = weightType.getDimSize(1); + const int64_t outputHeight = outputType.getDimSize(1); + + if (!ShapedType::isDynamic(inputHeight) && + !ShapedType::isDynamic(outputHeight)) { + if (outputHeight != + (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight) + return emitOpError( + "dimension mismatch: expected OH == (IH - 1) * stride_y " + "+ out_pad_top + out_pad_bottom + KH, but got ") + << outputHeight << " != (" << inputHeight << " - 1) * " + << strideY << " + " << outPadTop << " + " << outPadBottom + << " + " << kernelHeight; + } - const int64_t inputWidth = inputType.getDimSize(2); - const int64_t outputWidth = outputType.getDimSize(2); + const int64_t inputWidth = inputType.getDimSize(2); + const int64_t kernelWidth = weightType.getDimSize(2); + const int64_t outputWidth = outputType.getDimSize(2); - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(outputWidth)) { - if (outputWidth != - (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth) - return emitOpError( - "dimension mismatch: expected OW == (IW - 1) * stride_x " - "+ out_pad_left + out_pad_right + KW, but got ") - << outputWidth << " != (" << inputWidth << " - 1) * " << strideX - << " + " << outPadLeft << " + " << outPadRight << " + " - << kernelWidth; + if (!ShapedType::isDynamic(inputWidth) && + !ShapedType::isDynamic(outputWidth)) { + if (outputWidth != + (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth) + return emitOpError( + "dimension mismatch: expected OW == (IW - 1) * stride_x " + "+ out_pad_left + out_pad_right + KW, but got ") + << outputWidth << " != (" << inputWidth << " - 1) * " << strideX + << " + " << outPadLeft << " + " << outPadRight << " + " + << kernelWidth; + } } const auto biasType = llvm::dyn_cast(getBias().getType()); diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 4a341d583426a..4957c1684994d 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -22,22 +22,12 @@ func.func @test_const_non_tensor_attr() { // ----- -func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { +func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op expect both input and weight to be float or not together, got 'f32' and 'i8'}} %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array, pad = array, stride = array} - : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> - return %0 : tensor<1x27x27x16xi8> -} - -// ----- - -func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { - %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> - // expected-error@+1 {{'tosa.conv2d' op expect a ranked tensor for input, got of type 'tensor<*xi8>' at index: 0}} - %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = 
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (inputType && weightType) {
+    const int64_t inputHeight = inputType.getDimSize(1);
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    const int64_t outputHeight = outputType.getDimSize(1);
+
+    if (!ShapedType::isDynamic(inputHeight) &&
+        !ShapedType::isDynamic(outputHeight)) {
+      if (outputHeight !=
+          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
+        return emitOpError(
+                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
+                   "+ out_pad_top + out_pad_bottom + KH, but got ")
+               << outputHeight << " != (" << inputHeight << " - 1) * "
+               << strideY << " + " << outPadTop << " + " << outPadBottom
+               << " + " << kernelHeight;
+    }

-  const int64_t inputWidth = inputType.getDimSize(2);
-  const int64_t outputWidth = outputType.getDimSize(2);
+    const int64_t inputWidth = inputType.getDimSize(2);
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    const int64_t outputWidth = outputType.getDimSize(2);

-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(outputWidth)) {
-    if (outputWidth !=
-        (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
-      return emitOpError(
-                 "dimension mismatch: expected OW == (IW - 1) * stride_x "
-                 "+ out_pad_left + out_pad_right + KW, but got ")
-             << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
-             << " + " << outPadLeft << " + " << outPadRight << " + "
-             << kernelWidth;
+    if (!ShapedType::isDynamic(inputWidth) &&
+        !ShapedType::isDynamic(outputWidth)) {
+      if (outputWidth !=
+          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
+        return emitOpError(
+                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
+                   "+ out_pad_left + out_pad_right + KW, but got ")
+               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
+               << " + " << outPadLeft << " + " << outPadRight << " + "
+               << kernelWidth;
+    }
   }

   const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 4a341d583426a..4957c1684994d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -22,22 +22,12 @@ func.func @test_const_non_tensor_attr() {

 // -----

-func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv2d' op expect both input and weight to be float or not together, got 'f32' and 'i8'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-    : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
-  return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
-func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-    : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
+    : (tensor<*xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }

@@ -45,7 +35,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t

 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+  // expected-error@+1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
     : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b64074e412ed1..f8273190bde40 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -70,6 +70,13 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
   return %0 : tensor<1x4x4x8xf32>
 }

+// -----
+// CHECK-LABEL: conv2d_unranked_input
+func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: conv2d_quant_uniform
 func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform>, %arg1: tensor<8x1x1x4x!quant.uniform>, %arg2: tensor<8x!quant.uniform>) -> tensor<1x4x4x8x!quant.uniform> {
@@ -202,6 +209,20 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
   return %0 : tensor<1x32x32x16xf32>
 }

+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_input
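+// Accepted: with an unranked input the verifier has no IH/IW to check the
+// OH/OW relationship against, so those checks are skipped.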
+func.func @test_transpose_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_weight
+func.func @test_transpose_conv2d_unranked_weight(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<*xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<*xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
 // -----
 // CHECK-LABEL: transpose_conv2d_with_local_bound
 func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {