diff --git a/pointblank/_utils.py b/pointblank/_utils.py index afb6a3958..a2efc9d31 100644 --- a/pointblank/_utils.py +++ b/pointblank/_utils.py @@ -229,6 +229,77 @@ def _check_column_exists(dfn: nw.DataFrame, column: str) -> None: raise ValueError(f"Column '{column}' not found in DataFrame.") +def _count_true_values_in_column( + tbl: FrameT, + column: str, + inverse: bool = False, +) -> int: + """ + Count the number of `True` values in a specified column of a table. + + Parameters + ---------- + tbl + A Narwhals-compatible DataFrame or table-like object. + column + The column in which to count the `True` values. + inverse + If `True`, count the number of `False` values instead. + + Returns + ------- + int + The count of `True` (or `False`) values in the specified column. + """ + + # Convert the DataFrame to a Narwhals DataFrame (no detrimental effect if + # already a Narwhals DataFrame) + tbl_nw = nw.from_native(tbl) + + # Filter the table based on the column and whether we want to count True or False values + tbl_filtered = tbl_nw.filter(nw.col(column) if not inverse else ~nw.col(column)) + + # Always collect table if it is a LazyFrame; this is required to get the row count + if _is_lazy_frame(tbl_filtered): + tbl_filtered = tbl_filtered.collect() + + return len(tbl_filtered) + + +def _count_null_values_in_column( + tbl: FrameT, + column: str, +) -> int: + """ + Count the number of Null values in a specified column of a table. + + Parameters + ---------- + tbl + A Narwhals-compatible DataFrame or table-like object. + column + The column in which to count the Null values. + + Returns + ------- + int + The count of Null values in the specified column. 
+ """ + + # Convert the DataFrame to a Narwhals DataFrame (no detrimental effect if + # already a Narwhals DataFrame) + tbl_nw = nw.from_native(tbl) + + # Filter the table to get rows where the specified column is Null + tbl_filtered = tbl_nw.filter(nw.col(column).is_null()) + + # Always collect table if it is a LazyFrame; this is required to get the row count + if _is_lazy_frame(tbl_filtered): + tbl_filtered = tbl_filtered.collect() + + return len(tbl_filtered) + + def _is_numeric_dtype(dtype: str) -> bool: """ Check if a given data type string represents a numeric type. diff --git a/pointblank/validate.py b/pointblank/validate.py index bb73857c2..558b5ea20 100644 --- a/pointblank/validate.py +++ b/pointblank/validate.py @@ -64,6 +64,8 @@ from pointblank._utils import ( _check_any_df_lib, _check_invalid_fields, + _count_null_values_in_column, + _count_true_values_in_column, _derive_bounds, _format_to_integer_value, _get_fn_name, @@ -8504,36 +8506,32 @@ def interrogate( else: # If the result is not a list, then we assume it's a table in the conventional - # form (where the column is `pb_is_good_` exists, with boolean values) - + # form (where the column `pb_is_good_` exists, with boolean values) results_tbl = results_tbl_list # If the results table is not `None`, then we assume there is a table with a column # called `pb_is_good_` that contains boolean values; we can then use this table to # determine the number of test units that passed and failed if results_tbl is not None: - # Extract the `pb_is_good_` column from the table as a results list - if tbl_type in IBIS_BACKENDS: - # Select the DataFrame library to use for getting the results list - df_lib = _select_df_lib(preference="polars") - df_lib_name = df_lib.__name__ - - if df_lib_name == "pandas": - results_list = ( - results_tbl.select("pb_is_good_").to_pandas()["pb_is_good_"].to_list() - ) - else: - results_list = ( - results_tbl.select("pb_is_good_").to_polars()["pb_is_good_"].to_list() - ) + # Count the 
number of passing and failing test units + validation.n_passed = _count_true_values_in_column( + tbl=results_tbl, column="pb_is_good_" + ) + validation.n_failed = _count_true_values_in_column( + tbl=results_tbl, column="pb_is_good_", inverse=True + ) - else: - results_list = nw.from_native(results_tbl)["pb_is_good_"].to_list() + # Solely for the col_vals_in_set assertion type, any Null values in the + # `pb_is_good_` column are counted as failing test units + if assertion_type == "col_vals_in_set": + null_count = _count_null_values_in_column(tbl=results_tbl, column="pb_is_good_") + validation.n_failed += null_count + + # For column-value validations, the number of test units is the number of rows + validation.n = get_row_count(data=results_tbl) - validation.all_passed = all(results_list) - validation.n = len(results_list) - validation.n_passed = results_list.count(True) - validation.n_failed = results_list.count(False) + # Set the `all_passed` attribute based on whether there are any failing test units + validation.all_passed = validation.n_failed == 0 # Calculate fractions of passing and failing test units # - `f_passed` is the fraction of test units that passed diff --git a/pyproject.toml b/pyproject.toml index d6d59ff1b..7ea91b0e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "commonmark>=0.9.1", "importlib-metadata", "great_tables>=0.17.0", - "narwhals>=1.24.1", + "narwhals>=1.41.0", "typing_extensions>=3.10.0.0", "requests>=2.31.0", ] diff --git a/tests/test__utils.py b/tests/test__utils.py index 25a9f456c..fb9b4828e 100644 --- a/tests/test__utils.py +++ b/tests/test__utils.py @@ -13,6 +13,8 @@ _check_column_type, _check_invalid_fields, _column_test_prep, + _count_true_values_in_column, + _count_null_values_in_column, _format_to_float_value, _format_to_integer_value, _get_assertion_from_fname, @@ -28,6 +30,8 @@ _select_df_lib, ) +from pointblank.validate import load_dataset + @pytest.fixture def tbl_pd(): @@ -325,6 +329,21 
@@ def test_check_column_test_prep_raises(request, tbl_fixture): _column_test_prep(df=tbl, column="invalid", allowed_types=["numeric"]) +@pytest.mark.parametrize("tbl_type", ["polars", "duckdb"]) +def test_count_true_values_in_column(tbl_type): + data = load_dataset(dataset="small_table", tbl_type=tbl_type) + + assert _count_true_values_in_column(tbl=data, column="e") == 8 + assert _count_true_values_in_column(tbl=data, column="e", inverse=True) == 5 + + +@pytest.mark.parametrize("tbl_type", ["polars", "duckdb"]) +def test_count_null_values_in_column(tbl_type): + data = load_dataset(dataset="small_table", tbl_type=tbl_type) + + assert _count_null_values_in_column(tbl=data, column="c") == 2 + + def test_format_to_integer_value(): assert _format_to_integer_value(0) == "0" assert _format_to_integer_value(0.3) == "0" diff --git a/tests/test_validate.py b/tests/test_validate.py index e74c6b4ec..2d3e550d3 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -374,14 +374,14 @@ def test_validate_class_lang_locale(): def test_null_vals_in_set(data: Any) -> None: validate = ( Validate(data) - .col_vals_in_set(["foo"], set=[1, 2, None]) - .col_vals_in_set(["bar"], set=["winston", "cat", None]) + .col_vals_in_set(columns="foo", set=[1, 2, None]) + .col_vals_in_set(columns="bar", set=["winston", "cat", None]) .interrogate() ) validate.assert_passing() - validate = Validate(data).col_vals_in_set("foo", [1, 2]).interrogate() + validate = Validate(data).col_vals_in_set(columns="foo", set=[1, 2]).interrogate() with pytest.raises(AssertionError): validate.assert_passing()