diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 607667fcc6945..dfab37497d5c2 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -524,6 +524,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa(op)) { Type srcTy = elementTy; Type dstTy = resultTypes.front(); + if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) { + (void)rewriter.notifyMatchFailure(op, "unsupported type"); + return nullptr; + } + bool bitExtend = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index 460e207d62de6..5db3f56cf459e 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -54,3 +54,11 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3 %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> return %0 : tensor<2x3x4xf32> } + +// ----- + +func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> { + // expected-error@+1 {{failed to legalize operation 'tosa.cast'}} + %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +}