diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py
index 7b5f88687..f6ae92b45 100644
--- a/src/databricks/labs/dqx/check_funcs.py
+++ b/src/databricks/labs/dqx/check_funcs.py
@@ -550,7 +550,7 @@ def is_not_equal_to(
 
 @register_rule("row")
 def is_not_less_than(
-    column: str | Column, limit: int | datetime.date | datetime.datetime | str | Column | None = None
+    column: str | Column, limit: int | float | datetime.date | datetime.datetime | str | Column | None = None
 ) -> Column:
     """Checks whether the values in the input column are not less than the provided limit.
 
@@ -580,7 +580,7 @@ def is_not_less_than(
 
 @register_rule("row")
 def is_not_greater_than(
-    column: str | Column, limit: int | datetime.date | datetime.datetime | str | Column | None = None
+    column: str | Column, limit: int | float | datetime.date | datetime.datetime | str | Column | None = None
 ) -> Column:
     """Checks whether the values in the input column are not greater than the provided limit.
 
@@ -611,8 +611,8 @@ def is_not_greater_than(
 @register_rule("row")
 def is_in_range(
     column: str | Column,
-    min_limit: int | datetime.date | datetime.datetime | str | Column | None = None,
-    max_limit: int | datetime.date | datetime.datetime | str | Column | None = None,
+    min_limit: int | float | datetime.date | datetime.datetime | str | Column | None = None,
+    max_limit: int | float | datetime.date | datetime.datetime | str | Column | None = None,
 ) -> Column:
     """Checks whether the values in the input column are in the provided limits (inclusive of both boundaries).
 
@@ -649,15 +649,15 @@ def is_in_range(
 @register_rule("row")
 def is_not_in_range(
     column: str | Column,
-    min_limit: int | datetime.date | datetime.datetime | str | Column | None = None,
-    max_limit: int | datetime.date | datetime.datetime | str | Column | None = None,
+    min_limit: int | float | datetime.date | datetime.datetime | str | Column | None = None,
+    max_limit: int | float | datetime.date | datetime.datetime | str | Column | None = None,
 ) -> Column:
     """Checks whether the values in the input column are outside the provided limits (inclusive of both boundaries).
 
     Args:
         column: column to check; can be a string column name or a column expression
         min_limit: min limit to use in the condition as number, date, timestamp, column name or sql expression
-        max_limit: min limit to use in the condition as number, date, timestamp, column name or sql expression
+        max_limit: max limit to use in the condition as number, date, timestamp, column name or sql expression
 
     Returns:
         new Column
diff --git a/tests/integration/test_row_checks.py b/tests/integration/test_row_checks.py
index da5f0bd81..9b29d7330 100644
--- a/tests/integration/test_row_checks.py
+++ b/tests/integration/test_row_checks.py
@@ -678,13 +678,13 @@ def test_is_col_older_than_n_days_cur(spark):
 
 
 def test_col_is_not_less_than(spark, set_utc_timezone):
-    schema_num = "a: int, b: int, c: date, d: timestamp, e: decimal(10,2), f: array<int>, g: map<string, int>"
+    schema_num = "a: int, b: int, c: date, d: timestamp, e: decimal(10,2), f: array<int>, g: map<string, int>, h: float"
     test_df = spark.createDataFrame(
         [
-            [1, 1, datetime(2025, 1, 1).date(), datetime(2025, 1, 1), Decimal("1.00"), [1], {"val": 1}],
-            [2, 4, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), Decimal("1.99"), [2], {"val": 2}],
-            [4, 3, None, None, Decimal("2.01"), [4], {"val": 4}],
-            [None, None, None, None, None, [None], {"val": None}],
+            [1, 1, datetime(2025, 1, 1).date(), datetime(2025, 1, 1), Decimal("1.00"), [1], {"val": 1}, 1.2],
+            [2, 4, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), Decimal("1.99"), [2], {"val": 2}, 3.6],
+            [4, 3, None, None, Decimal("2.01"), [4], {"val": 4}, 4.8],
+            [None, None, None, None, None, [None], {"val": None}, None],
         ],
         schema_num,
    )
@@ -698,13 +698,15 @@ def test_col_is_not_less_than(spark, set_utc_timezone):
         is_not_less_than("e", 2),
         is_not_less_than(F.try_element_at("f", F.lit(1)), 2),
         is_not_less_than(F.col("g").getItem("val"), 2),
+        is_not_less_than("h", 2.4),
     )
 
     checked_schema = (
         "a_less_than_limit: string, a_less_than_limit: string, b_less_than_limit: string, "
         "c_less_than_limit: string, d_less_than_limit: string, e_less_than_limit: string, "
         "try_element_at_f_1_less_than_limit: string, "
-        "unresolvedextractvalue_g_val_less_than_limit: string"
+        "unresolvedextractvalue_g_val_less_than_limit: string, "
+        "h_less_than_limit: string"
     )
     expected = spark.createDataFrame(
         [
@@ -718,6 +720,7 @@ def test_col_is_not_less_than(spark, set_utc_timezone):
                 "Value '1.00' in Column 'e' is less than limit: 2",
                 "Value '1' in Column 'try_element_at(f, 1)' is less than limit: 2",
                 "Value '1' in Column 'UnresolvedExtractValue(g, val)' is less than limit: 2",
+                "Value '1.2' in Column 'h' is less than limit: 2.4",
             ],
             [
                 None,
@@ -728,6 +731,7 @@ def test_col_is_not_less_than(spark, set_utc_timezone):
                 "Value '1.99' in Column 'e' is less than limit: 2",
                 None,
                 None,
+                None,
             ],
             [
                 None,
@@ -738,8 +742,9 @@ def test_col_is_not_less_than(spark, set_utc_timezone):
                 None,
                 None,
                 None,
+                None,
             ],
-            [None, None, None, None, None, None, None, None],
+            [None, None, None, None, None, None, None, None, None],
         ],
         checked_schema,
     )
@@ -748,13 +753,13 @@ def test_col_is_not_less_than(spark, set_utc_timezone):
 
 
 def test_col_is_not_greater_than(spark, set_utc_timezone):
-    schema_num = "a: int, b: int, c: date, d: timestamp, e: decimal(10,2), f: array<int>"
+    schema_num = "a: int, b: int, c: date, d: timestamp, e: decimal(10,2), f: array<int>, g: float"
     test_df = spark.createDataFrame(
         [
-            [1, 1, datetime(2025, 1, 1).date(), datetime(2025, 1, 1), Decimal("1.00"), [1]],
-            [2, 4, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), Decimal("1.01"), [2]],
-            [8, 3, None, None, Decimal("0.99"), [8]],
Decimal("0.99"), [8]], - [None, None, None, None, None, [None]], + [1, 1, datetime(2025, 1, 1).date(), datetime(2025, 1, 1), Decimal("1.00"), [1], 1.2], + [2, 4, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), Decimal("1.01"), [2], 3.6], + [8, 3, None, None, Decimal("0.99"), [8], 4.8], + [None, None, None, None, None, [None], None], ], schema_num, ) @@ -767,16 +772,17 @@ def test_col_is_not_greater_than(spark, set_utc_timezone): is_not_greater_than("d", datetime(2025, 1, 1)), is_not_greater_than("e", 1), is_not_greater_than(F.try_element_at("f", F.lit(1)), 1), + is_not_greater_than("g", 2.4), ) checked_schema = ( "a_greater_than_limit: string, a_greater_than_limit: string, b_greater_than_limit: string, " "c_greater_than_limit: string, d_greater_than_limit: string, e_greater_than_limit: string, " - "try_element_at_f_1_greater_than_limit: string" + "try_element_at_f_1_greater_than_limit: string, g_greater_than_limit: string" ) expected = spark.createDataFrame( [ - [None, None, None, None, None, None, None], + [None, None, None, None, None, None, None, None], [ "Value '2' in Column 'a' is greater than limit: 1", None, @@ -785,6 +791,7 @@ def test_col_is_not_greater_than(spark, set_utc_timezone): "Value '2025-02-01 00:00:00' in Column 'd' is greater than limit: 2025-01-01 00:00:00", "Value '1.01' in Column 'e' is greater than limit: 1", "Value '2' in Column 'try_element_at(f, 1)' is greater than limit: 1", + "Value '3.6' in Column 'g' is greater than limit: 2.4", ], [ "Value '8' in Column 'a' is greater than limit: 1", @@ -794,8 +801,9 @@ def test_col_is_not_greater_than(spark, set_utc_timezone): None, None, "Value '8' in Column 'try_element_at(f, 1)' is greater than limit: 1", + "Value '4.8' in Column 'g' is greater than limit: 2.4", ], - [None, None, None, None, None, None, None], + [None, None, None, None, None, None, None, None], ], checked_schema, ) @@ -804,15 +812,15 @@ def test_col_is_not_greater_than(spark, set_utc_timezone): def test_col_is_in_range(spark, set_utc_timezone): - schema_num = "a: int, b: date, c: timestamp, d: int, e: int, f: int, g: decimal(10,2), h: map" + schema_num = "a: int, b: date, c: timestamp, d: int, e: int, f: int, g: decimal(10,2), h: map, i:float" test_df = spark.createDataFrame( [ - [0, datetime(2024, 12, 1).date(), datetime(2024, 12, 1), -1, 5, 6, Decimal("2.00"), {"val": 0}], - [1, datetime(2025, 1, 1).date(), datetime(2025, 1, 1), 2, 6, 3, Decimal("1.00"), {"val": 1}], - [2, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), 2, 7, 3, Decimal("3.00"), {"val": 2}], - [3, datetime(2025, 3, 1).date(), datetime(2025, 3, 1), 3, 8, 3, Decimal("1.01"), {"val": 3}], - [4, datetime(2025, 4, 1).date(), datetime(2025, 4, 1), 2, 9, 3, Decimal("3.01"), {"val": 4}], - [None, None, None, None, None, None, None, {"val": None}], + [0, datetime(2024, 12, 1).date(), datetime(2024, 12, 1), -1, 5, 6, Decimal("2.00"), {"val": 0}, 0.0], + [1, datetime(2025, 1, 1).date(), datetime(2025, 1, 1), 2, 6, 3, Decimal("1.00"), {"val": 1}, 0.2], + [2, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), 2, 7, 3, Decimal("3.00"), {"val": 2}, 0.4], + [3, datetime(2025, 3, 1).date(), datetime(2025, 3, 1), 3, 8, 3, Decimal("1.01"), {"val": 3}, 0.6], + [4, datetime(2025, 4, 1).date(), datetime(2025, 4, 1), 2, 9, 3, Decimal("3.01"), {"val": 4}, 0.8], + [None, None, None, None, None, None, None, {"val": None}, None], ], schema_num, ) @@ -827,12 +835,13 @@ def test_col_is_in_range(spark, set_utc_timezone): is_in_range("f", "a", 5), is_in_range("g", 1, 3), is_in_range(F.col("h").getItem("val"), 
+        is_in_range("i", 0.1, 0.7),
     )
 
     checked_schema = (
         "a_not_in_range: string, b_not_in_range: string, c_not_in_range: string, "
         "d_not_in_range: string, f_not_in_range: string, g_not_in_range: string, "
-        "unresolvedextractvalue_h_val_not_in_range: string"
+        "unresolvedextractvalue_h_val_not_in_range: string, i_not_in_range: string"
     )
     expected = spark.createDataFrame(
         [
@@ -844,10 +853,11 @@ def test_col_is_in_range(spark, set_utc_timezone):
                 "Value '6' in Column 'f' not in range: [0, 5]",
                 None,
                 "Value '0' in Column 'UnresolvedExtractValue(h, val)' not in range: [1, 3]",
+                "Value '0.0' in Column 'i' not in range: [0.1, 0.7]",
             ],
-            [None, None, None, None, None, None, None],
-            [None, None, None, None, None, None, None],
-            [None, None, None, None, None, None, None],
+            [None, None, None, None, None, None, None, None],
+            [None, None, None, None, None, None, None, None],
+            [None, None, None, None, None, None, None, None],
             [
                 "Value '4' in Column 'a' not in range: [1, 3]",
                 "Value '2025-04-01' in Column 'b' not in range: [2025-01-01, 2025-03-01]",
@@ -856,8 +866,9 @@ def test_col_is_in_range(spark, set_utc_timezone):
                 "Value '3' in Column 'f' not in range: [4, 5]",
                 "Value '3.01' in Column 'g' not in range: [1, 3]",
                 "Value '4' in Column 'UnresolvedExtractValue(h, val)' not in range: [1, 3]",
+                "Value '0.8' in Column 'i' not in range: [0.1, 0.7]",
             ],
-            [None, None, None, None, None, None, None],
+            [None, None, None, None, None, None, None, None],
         ],
         checked_schema,
     )
@@ -866,13 +877,21 @@ def test_col_is_in_range(spark, set_utc_timezone):
 
 
 def test_col_is_not_in_range(spark, set_utc_timezone):
-    schema_num = "a: int, b: date, c: timestamp, d: timestamp, e: decimal(10,2), f: array<int>"
+    schema_num = "a: int, b: date, c: timestamp, d: timestamp, e: decimal(10,2), f: array<int>, g: float"
     test_df = spark.createDataFrame(
         [
-            [0, datetime(2024, 12, 31).date(), datetime(2025, 1, 4), datetime(2025, 1, 7), Decimal("0.99"), [0, 1]],
-            [1, datetime(2025, 1, 1).date(), datetime(2025, 1, 3), datetime(2025, 1, 1), Decimal("1.00"), [1, 2]],
-            [3, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), datetime(2025, 2, 3), Decimal("3.00"), [3, 4]],
-            [None, None, None, None, None, [None, 1]],
+            [
+                0,
+                datetime(2024, 12, 31).date(),
+                datetime(2025, 1, 4),
+                datetime(2025, 1, 7),
+                Decimal("0.99"),
+                [0, 1],
+                0.0,
+            ],
+            [1, datetime(2025, 1, 1).date(), datetime(2025, 1, 3), datetime(2025, 1, 1), Decimal("1.00"), [1, 2], 0.3],
+            [3, datetime(2025, 2, 1).date(), datetime(2025, 2, 1), datetime(2025, 2, 3), Decimal("3.00"), [3, 4], 0.6],
+            [None, None, None, None, None, [None, 1], None],
         ],
         schema_num,
     )
@@ -886,15 +905,16 @@ def test_col_is_not_in_range(spark, set_utc_timezone):
         is_not_in_range("d", "c", F.expr("cast(b as timestamp) + INTERVAL 2 DAY")),
         is_not_in_range("e", 1, 3),
         is_not_in_range(F.try_element_at("f", F.lit(1)), 1, 3),
+        is_not_in_range("g", 0.2, 0.5),
     )
 
     checked_schema = (
         "a_in_range: string, b_in_range: string, c_in_range: string, d_in_range: string, e_in_range: string, "
-        "try_element_at_f_1_in_range: string"
+        "try_element_at_f_1_in_range: string, g_in_range: string"
    )
     expected = spark.createDataFrame(
         [
-            [None, None, None, None, None, None],
+            [None, None, None, None, None, None, None],
             [
                 "Value '1' in Column 'a' in range: [1, 3]",
                 "Value '2025-01-01' in Column 'b' in range: [2025-01-01, 2025-01-03]",
@@ -902,6 +922,7 @@ def test_col_is_not_in_range(spark, set_utc_timezone):
                 None,
                 "Value '1.00' in Column 'e' in range: [1, 3]",
                 "Value '1' in Column 'try_element_at(f, 1)' in range: [1, 3]",
+                "Value '0.3' in Column 'g' in range: [0.2, 0.5]",
             ],
             [
                 "Value '3' in Column 'a' in range: [1, 3]",
@@ -910,8 +931,9 @@ def test_col_is_not_in_range(spark, set_utc_timezone):
                 "Value '2025-02-03 00:00:00' in Column 'd' in range: [2025-02-01 00:00:00, 2025-02-03 00:00:00]",
                 "Value '3.00' in Column 'e' in range: [1, 3]",
                 "Value '3' in Column 'try_element_at(f, 1)' in range: [1, 3]",
+                None,
             ],
-            [None, None, None, None, None, None],
+            [None, None, None, None, None, None, None],
         ],
         checked_schema,
     )
diff --git a/tests/unit/test_checks_validation.py b/tests/unit/test_checks_validation.py
index 080d602de..936a2d814 100644
--- a/tests/unit/test_checks_validation.py
+++ b/tests/unit/test_checks_validation.py
@@ -433,3 +433,26 @@ def test_argument_type_list_mismatch_args():
         "Item 1 in argument 'columns' should be of type 'str | pyspark.sql.column.Column' "
         "for function 'is_unique' in the 'arguments' block" in str(status)
     )
+
+
+def test_is_in_range_float_arguments():
+    checks = [
+        {
+            "criticality": "warn",
+            "check": {"function": "is_in_range", "arguments": {"column": "a", "min_limit": 1.5, "max_limit": 2.5}},
+        },
+        {
+            "criticality": "warn",
+            "check": {"function": "is_in_range", "arguments": {"column": "b", "min_limit": 0.1, "max_limit": 0.9}},
+        },
+        {
+            "criticality": "warn",
+            "check": {"function": "is_not_in_range", "arguments": {"column": "c", "min_limit": 1.5, "max_limit": 2.5}},
+        },
+        {
+            "criticality": "warn",
+            "check": {"function": "is_not_in_range", "arguments": {"column": "c", "min_limit": 0.1, "max_limit": 0.9}},
+        },
+    ]
+    status = DQEngine.validate_checks(checks)
+    assert not status.has_errors
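
Usage note (not part of the patch): the signature change lets range checks take float limits directly, applied via a plain DataFrame.select exactly as the integration tests above do. A minimal sketch, assuming a live SparkSession bound to spark; the column name "temperature" and the limit values are illustrative only:

    from databricks.labs.dqx.check_funcs import is_in_range, is_not_greater_than

    # One value inside [35.0, 40.0] and one outside; each check function
    # returns a string column holding the failure message, or NULL on pass.
    df = spark.createDataFrame([(36.6,), (41.2,)], "temperature: float")
    result = df.select(
        is_in_range("temperature", 35.0, 40.0),
        is_not_greater_than("temperature", 40.5),
    )
    result.show(truncate=False)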