From 34a4e53d8be87ed1f8916edc29410c1606f5a839 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 2 Apr 2025 15:29:07 +0000 Subject: [PATCH] [mlir][tosa] Add error_if checks to clamp op verifier Specifically it introduces checks for: - ERROR_IF(max_val < min_val) - ERROR_IF(isNaN(min_val) || isNaN(max_val)) Change-Id: Id3fd81868df7ce7096c219bb61f903f1105039c5 Signed-off-by: Luke Hutton --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 17 +++++++++++++ mlir/test/Dialect/Tosa/invalid.mlir | 36 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index cdba332792eb0..de9adcb9ea1fc 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -663,6 +663,13 @@ LogicalResult tosa::ClampOp::verify() { (intMaxValAttr.getType() != inputETy)) return emitOpError("min/max attributes types are incompatible with " "input/output element types."); + + const bool isUnsigned = cast(inputETy).isUnsigned(); + const APInt minVal = intMinValAttr.getValue(); + const APInt maxVal = intMaxValAttr.getValue(); + if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal)) + return emitOpError("expected min_val <= max_val, got min_val=") + << minValAttr << ", max_val=" << maxValAttr; } else { // otherwise, input datatype is float, check that the min_val/max_val // attributes share the same type and that their type is the same as the @@ -674,6 +681,16 @@ LogicalResult tosa::ClampOp::verify() { (floatMaxValAttr.getType() != inputETy)) return emitOpError("min/max attributes types are incompatible with " "input/output element types."); + + const APFloat minVal = floatMinValAttr.getValue(); + const APFloat maxVal = floatMaxValAttr.getValue(); + if (minVal.isNaN() || maxVal.isNaN()) + return emitOpError("min/max attributes should not be 'NaN', got min_val=") + << minValAttr << ", max_val=" << maxValAttr; + + if (maxVal < minVal) + return emitOpError("expected min_val <= max_val, got min_val=") + << minValAttr << ", max_val=" << maxValAttr; } return success(); diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index ac8a247da24a7..3cac5eb06799d 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1939,3 +1939,39 @@ func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16> return %0 : tensor<13x21x3xi16> } + +// ----- + +// CHECK-LABEL: test_clamp_nan_min_val +func.func @test_clamp_nan_min_val(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.clamp' op min/max attributes should not be 'NaN', got min_val=0xFFFFFFFF : f32, max_val=1.000000e+00 : f32}} + %0 = tosa.clamp %arg0 {min_val = 0xFFFFFFFF : f32, max_val = 1.0: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_clamp_nan_max_val +func.func @test_clamp_nan_max_val(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.clamp' op min/max attributes should not be 'NaN', got min_val=2.300000e+00 : f32, max_val=0x7FFFFFFF : f32}} + %0 = tosa.clamp %arg0 {min_val = 2.3 : f32, max_val = 0x7FFFFFFF: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_clamp_min_larger_than_max_int8 +func.func @test_clamp_min_larger_than_max_int8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { + // expected-error@+1 {{'tosa.clamp' op expected min_val <= max_val, got min_val=127 : i8, max_val=-128 : i8}} + %0 = tosa.clamp %arg0 {min_val = 127 : i8, max_val = -128: i8} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> + return %0 : tensor<13x21x3xi8> +} + +// ----- + +// CHECK-LABEL: test_clamp_min_larger_than_max_fp32 +func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.clamp' op expected min_val <= max_val, got min_val=2.000000e+00 : f32, max_val=-1.100000e+00 : f32}} + %0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +}