diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index fd166cc1322ce..af4a5dc96265e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) { bool biasIsFloat = llvm::isa(biasEType); bool resultIsFloat = llvm::isa(resultEType); - if (auto quantType = - llvm::dyn_cast(inputEType)) + if (auto quantType = llvm::dyn_cast(inputEType)) inputEType = quantType.getStorageType(); - if (auto quantType = - llvm::dyn_cast(biasEType)) + if (auto quantType = llvm::dyn_cast(biasEType)) biasEType = quantType.getStorageType(); - if (auto quantType = - llvm::dyn_cast(resultEType)) + if (auto quantType = llvm::dyn_cast(resultEType)) resultEType = quantType.getStorageType(); if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { @@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) { auto inputEType = llvm::cast(op.getInput().getType()).getElementType(); - if (auto quantType = - llvm::dyn_cast(inputEType)) + if (auto quantType = llvm::dyn_cast(inputEType)) inputEType = quantType.getStorageType(); auto accType = op.getAccType(); @@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) { if (inputEType.isF32() && !accType.isF32()) return op.emitOpError("accumulator type for f32 tensor is not f32"); - return success(); + auto resultEType = + llvm::cast(op.getResult().getType()).getElementType(); + + if (auto quantType = llvm::dyn_cast(resultEType)) + resultEType = quantType.getStorageType(); + + // check allowed input/result element types combinations + if ((inputEType.isInteger(8) && resultEType.isInteger(32)) || + (inputEType.isInteger(16) && resultEType.isInteger(48)) || + (isa(inputEType) && resultEType.isF16()) || + (isa(inputEType) && resultEType.isF16()) || + (inputEType.isF16() && resultEType.isF16()) || + (inputEType.isBF16() && resultEType.isBF16()) || + (inputEType.isF32() && resultEType.isF32())) + return success(); + + return op.emitOpError("input/output element types are incompatible."); } // verify that inType and outType have same element types diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 2165e1f7ae3ba..20fc10d77d0e0 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -144,6 +144,24 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1 return %0 : tensor<1x32x32x16xi8> } +// ----- +// CHECK-LABEL: conv2d_quant_any_acc +func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any>>, %arg1: tensor<8x1x1x4x!quant.any>>, %arg2: tensor<8x!quant.any>>) -> tensor<1x4x4x8x!quant.any>> { + %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4x!quant.any>>, tensor<8x1x1x4x!quant.any>>, tensor<8x!quant.any>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any>> + return %0 : tensor<1x4x4x8x!quant.any>> +} + +// ----- +// CHECK-LABEL: conv2d_quant_any_result +func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any>>, %arg1: tensor<8x1x1x4x!quant.any>>, %arg2: tensor<8x!quant.any>>) -> tensor<1x4x4x8x!quant.any>> { + %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.conv2d' op input/output element types are incompatible}} + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4x!quant.any>>, tensor<8x1x1x4x!quant.any>>, tensor<8x!quant.any>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any>> + return %0 : tensor<1x4x4x8x!quant.any>> +} + // ----- func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor { diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index baf09e089aa30..d7e4f682c28b3 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -58,6 +58,22 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % return %0 : tensor<1x4x4x8xf32> } +// ----- +// CHECK-LABEL: conv2d_quant_uniform +func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform>, %arg1: tensor<8x1x1x4x!quant.uniform>, %arg2: tensor<8x!quant.uniform>) -> tensor<1x4x4x8x!quant.uniform> { + %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4x!quant.uniform>, tensor<8x1x1x4x!quant.uniform>, tensor<8x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.uniform> + return %0 : tensor<1x4x4x8x!quant.uniform> +} + +// ----- +// CHECK-LABEL: conv2d_quant_any +func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any>>, %arg1: tensor<8x1x1x4x!quant.any>>, %arg2: tensor<8x!quant.any>>) -> tensor<1x4x4x8x!quant.any>> { + %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4x!quant.any>>, tensor<8x1x1x4x!quant.any>>, tensor<8x!quant.any>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any>> + return %0 : tensor<1x4x4x8x!quant.any>> +} + // ----- // CHECK-LABEL: conv2d_q8xi4 func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {