diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2953e006bbe8d..92ab729f5b933 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
 
   let arguments = (ins
     Tosa_Tensor: $input,
-    I32Attr: $axis
+    I32Attr: $axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -287,7 +288,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr4:$pad
+    Tosa_IntArrayAttr4:$pad,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -388,7 +390,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     I64Attr:$min_int,
     I64Attr:$max_int,
     Tosa_FloatAttr:$min_fp,
-    Tosa_FloatAttr:$max_fp
+    Tosa_FloatAttr:$max_fp,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -752,7 +755,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -775,7 +779,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Tensor:$input2
+    Tosa_Tensor:$input2,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1382,7 +1387,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
@@ -1417,7 +1423,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I32Attr:$axis
+    I32Attr:$axis,
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 13325fb0ab9a2..5693acf3a01db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -205,12 +205,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
 // Supported regimes for tosa.resize.
 def Tosa_ResizeTypeAttr : StringBasedAttr<
     CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
           "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
     "Supported resize/upsampling strategies">;
 
+// Supported NaN propagation strategies.
+def Tosa_NanPropagationAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
+    "Supported NaN propagation strategies">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f7a596f1ccb19..8b883487d1659 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   }
 };
 
+// Attempts the following transformation:
+//
+// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
+// tensor X the following identity holds:
+//
+// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
+//
+// subject to the following valid NaN propagation semantics:
+// --------------------------------------------
+// | OUTER CLAMP | INNER CLAMP  | RESULT MODE |
+// |-------------|--------------|-------------|
+// | PROPAGATE   | PROPAGATE    | PROPAGATE   |
+// | PROPAGATE   | IGNORE       | IGNORE      |
+// | IGNORE      | PROPAGATE    | INVALID     |
+// | IGNORE      | IGNORE       | IGNORE      |
+// |------------------------------------------|
+
 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
 
+  // Helper structure to describe the range of a clamp operation.
+  template <typename T>
+  struct ClampRange {
+    ClampRange(const T &start, const T &end) : start(start), end(end) {}
+    T start;
+    T end;
+
+    // Helper function to determine if two Clamp ranges intersect.
+    bool intersects(const ClampRange &otherRange) {
+      return start < otherRange.end && otherRange.start < end;
+    }
+  };
+
   LogicalResult matchAndRewrite(tosa::ClampOp op,
                                 PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
+    // Check the input to the CLAMP op is itself a CLAMP.
+    auto clampOp =
+        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
+    if (!clampOp)
       return failure();
 
-    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
-      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
-      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
+    // Check we have a valid NaN propagation combination.
+    const auto opNanMode = op.getNanMode();
+    const auto clampNanMode = clampOp.getNanMode();
+    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+      return failure();
 
-      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
-      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
+    // Check we have intersecting ranges.
+    const auto opMinInt = op.getMinInt();
+    const auto opMaxInt = op.getMaxInt();
+    const auto clampOpMinInt = clampOp.getMinInt();
+    const auto clampOpMaxInt = clampOp.getMaxInt();
+    ClampRange opRangeIntRange(opMinInt, opMaxInt);
+    ClampRange clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
+    if (!opRangeIntRange.intersects(clampRangeIntRange))
+      return failure();
 
-      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-          op, op.getType(), clampOp.getInput(),
-          rewriter.getI64IntegerAttr(minInt),
-          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
-          rewriter.getF32FloatAttr(maxFp));
-      return success();
-    }
+    const auto opMinFloat = op.getMinFp();
+    const auto opMaxFloat = op.getMaxFp();
+    const auto clampOpMinFloat = clampOp.getMinFp();
+    const auto clampOpMaxFloat = clampOp.getMaxFp();
+    ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
+    ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
+    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
+      return failure();
 
-    return failure();
+    // Run the transformation.
+    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
+    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
+    const auto minInt = std::max(opMinInt, clampOpMinInt);
+    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, op.getType(), clampOp.getInput(),
+        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
+        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
+        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
+                                                           : opNanMode));
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e394188e9a931..6f47f041b9199 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -138,6 +138,58 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
 
 // -----
 
+// CHECK: @disjoint_clamp_twice_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
+func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = -5.000000e+00 : f32, max_int = -5 : i64, min_fp = -1.000000e+00 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 1.000000e+00 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  %0 = tosa.clamp %arg0 {max_fp = -5.0 : f32, max_int = -5 : i64, min_fp = -1.0 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 1.0 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp
+func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+  %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp
+func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
+  %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp
+func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
+  %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
+func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = 3.000000e+00 : f32, max_int = 4 : i64, min_fp = -5.000000e+00 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+  %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @concat_fold
 func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 563c5fa457d35..19b93d7611854 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -180,6 +180,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: clamp_propagate
+func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp_ignore
+func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: clamp_f16
 func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {