Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pointblank/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"not_null": ["str", "numeric", "bool", "datetime", "duration"],
}

ASSERTION_TYPE_METHOD_MAP = {
ASSERTION_TYPE_METHOD_MAP: dict[str, str] = {
"col_vals_pct_null": "pct_null",
"col_vals_gt": "gt",
"col_vals_lt": "lt",
"col_vals_eq": "eq",
Expand Down
28 changes: 28 additions & 0 deletions pointblank/_interrogation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

Expand All @@ -9,6 +10,7 @@
from narwhals.typing import FrameT

from pointblank._constants import IBIS_BACKENDS
from pointblank._typing import AbsoluteBounds
from pointblank._utils import (
_column_test_prep,
_convert_to_narwhals,
Expand Down Expand Up @@ -736,6 +738,32 @@ def row_count_match(data_tbl: FrameT, count, inverse: bool, abs_tol_bounds) -> b
return row_count >= min_val and row_count <= max_val


def col_vals_pct_null(
    data_tbl: FrameT, column: str, p: float, bound_finder: Callable[[int], AbsoluteBounds]
) -> bool:
    """Check whether the percentage of null values in a column is within `p`, given bounds.

    Args:
        data_tbl (FrameT): The data table to check.
        column (str): Column in the data whose null values are counted.
        p (float): Target proportion of null values out of the total row count.
        bound_finder (Callable[[int], AbsoluteBounds]): Function that takes a target number of
            null values and returns the absolute (lower, upper) bounds around that target.

    Returns:
        bool: `True` when the observed null count lies in
            `[abs_target - lower_bound, abs_target + upper_bound]`, `False` otherwise.
    """
    # TODO: Shouldn't be passing the whole dataframe for things like this.
    # Convert the target proportion into an absolute null-count target, rounded
    # to the nearest whole row, and derive the tolerance window around it.
    total_rows: int = data_tbl[column].len()
    abs_target: float = round(total_rows * p)
    lower_bound, upper_bound = bound_finder(abs_target)

    # Count observed nulls through narwhals so the check works across backends.
    n_null: int = nw.from_native(data_tbl).select(nw.col(column).is_null().sum()).item()

    # Observed null count must fall inside the tolerance window.
    return (abs_target - lower_bound) <= n_null <= (abs_target + upper_bound)


def col_count_match(data_tbl: FrameT, count, inverse: bool) -> bool:
"""
Check if DataFrame column count matches expected count.
Expand Down
67 changes: 66 additions & 1 deletion pointblank/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading
from dataclasses import dataclass
from enum import Enum
from functools import partial
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Callable, Literal
from zipfile import ZipFile
Expand Down Expand Up @@ -53,6 +54,7 @@
col_exists,
col_schema_match,
col_vals_expr,
col_vals_pct_null,
conjointly_validation,
interrogate_between,
interrogate_eq,
Expand Down Expand Up @@ -3963,6 +3965,55 @@ def set_tbl(
def _repr_html_(self) -> str:
return self.get_tabular_report()._repr_html_() # pragma: no cover

def col_vals_pct_null(
    self,
    columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals,
    p: float,
    tol: Tolerance = 0,
    thresholds: int | float | None | bool | tuple | dict | Thresholds = None,
    brief: str | bool | None = None,
) -> Validate:
    """Validate that the percentage of null values in a column matches `p`, within tolerance.

    Adds one validation step per resolved column; each step passes when the observed
    null count falls inside the tolerance window derived from `p` and `tol` at
    interrogation time.

    Args:
        columns (str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals):
            The column(s), or a column selector, to validate.
        p (float): Target proportion of values that should be null.
        tol (Tolerance, optional): Tolerance allowed against the total dataset.
            Defaults to `0`.
        thresholds (int | float | None | bool | tuple | dict | Thresholds, optional):
            Failure threshold levels for these steps; falls back to the validator-level
            thresholds when `None`.
        brief (str | bool | None, optional): A brief for these steps; falls back to the
            validator-level brief when `None`.

    Returns:
        Validate: The validator itself, for method chaining.
    """
    # If `columns` is a ColumnSelector or Narwhals selector, call `col()` on it to later
    # resolve the columns
    if isinstance(columns, (ColumnSelector, nw.selectors.Selector)):
        columns = col(columns)

    # If `columns` is Column value or a string, place it in a list for iteration
    if isinstance(columns, (Column, str)):
        columns = [columns]

    # Determine brief to use (global or local) and transform any shorthands of `brief=`
    brief = self.brief if brief is None else _transform_auto_brief(brief=brief)

    # Bind the tolerance now so interrogation can later turn an absolute null-count
    # target into absolute bounds.
    bound_finder: Callable[[int], AbsoluteBounds] = partial(_derive_bounds, tol=tol)

    # Determine thresholds to use (global or local), normalizing any local shorthand
    thresholds = (
        self.thresholds if thresholds is None else _normalize_thresholds_creation(thresholds)
    )

    # Iterate over the columns and create a validation step for each
    for column in columns:
        val_info = _ValidationInfo(
            # TODO: should type hint these as required args i think
            assertion_type="col_vals_pct_null",
            column=column,
            values={"p": p, "bound_finder": bound_finder},
            brief=brief,
            active=True,
            thresholds=thresholds,
        )

        self._add_validation(validation_info=val_info)

    return self

def col_vals_gt(
self,
columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals,
Expand Down Expand Up @@ -10335,7 +10386,7 @@ def interrogate(

start_time = datetime.datetime.now(datetime.timezone.utc)

assertion_type = validation.assertion_type
assertion_type: str = validation.assertion_type
column = validation.column
value = validation.values
inclusive = validation.inclusive
Expand Down Expand Up @@ -10655,6 +10706,20 @@ def interrogate(

results_tbl = None

elif assertion_type == "col_vals_pct_null":
results_bool: bool = col_vals_pct_null(
data_tbl=data_tbl_step,
column=column,
p=value["p"],
bound_finder=value["bound_finder"],
)
validation.all_passed = results_bool
validation.n = 1
validation.n_passed = int(results_bool)
validation.n_failed = 1 - int(results_bool)

results_tbl = None

elif assertion_type == "conjointly":
results_tbl = conjointly_validation(
data_tbl=data_tbl_step,
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import polars as pl
import pytest


@pytest.fixture
def half_null_ser() -> pl.Series:
    """A 1k element half null series. Exists to get around rounding issues."""
    values = [i if i % 2 else None for i in range(1000)]
    return pl.Series("half_null", values)
123 changes: 120 additions & 3 deletions tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5840,9 +5840,6 @@ def test_row_count_example_tol() -> None:
)


test_row_count_example_tol()


@pytest.mark.parametrize(
("nrows", "target_count", "tol", "should_pass"),
[
Expand Down Expand Up @@ -17147,6 +17144,126 @@ def test_original_table_never_modified_without_pre(request, tbl_fixture):
assert all(n == expected_rows for n in n_values)


def test_pct_null_simple() -> None:
    """col_vals_pct_null() passes when both columns are exactly half null."""
    frame = pl.DataFrame({"a": [1, None, 3, None], "b": [None, None, 3, 4]})
    result = Validate(frame).col_vals_pct_null(columns=["a", "b"], p=0.5).interrogate()

    result.assert_passing()
    result.assert_below_threshold()

    # One validation step per column
    assert len(result.validation_info) == 2


def test_pct_null_simple_fail() -> None:
    """col_vals_pct_null() fails when the target percentage is far from reality."""
    frame = pl.DataFrame({"a": [1, None, 3, None], "b": [None, None, 3, 4]})
    result = (
        Validate(frame)
        .col_vals_pct_null(columns=["a", "b"], p=0.1, tol=0.0001, thresholds=1)
        .interrogate()
    )

    # Both the passing check and the threshold check should raise
    for check in (result.assert_passing, result.assert_below_threshold):
        with pytest.raises(AssertionError):
            check()

    # One validation step per column
    assert len(result.validation_info) == 2


@pytest.mark.xfail(reason="No SVG for pct null?")
def test_pct_null_simple_report() -> None:
    """Tabular report generation for col_vals_pct_null() steps."""
    frame = pl.DataFrame({"a": [1, None, 3, None], "b": [None, None, 3, 4]})
    result = (
        Validate(frame)
        .col_vals_pct_null(columns=["a", "b"], p=0.1, tol=0.0001, thresholds=1)
        .interrogate()
    )

    result.get_tabular_report()


def test_pct_null_exact_match_with_tol() -> None:
    """Should pass if pct null matches exactly, even with tol."""
    frame = pl.DataFrame({"a": [None, 1, 2, 3]})  # one null out of four -> 25%
    outcome = Validate(frame).col_vals_pct_null(columns=["a"], p=0.25, tol=0.0).interrogate()
    outcome.assert_passing()


def test_pct_null_within_tol_pass() -> None:
    """Should pass if pct null is within tolerance margin."""
    frame = pl.DataFrame({"a": [None, None, 1, 2]})  # two nulls out of four -> 50%
    # p=0.4 with tol=0.1 accepts anything in [0.3, 0.5]
    outcome = Validate(frame).col_vals_pct_null(columns=["a"], p=0.4, tol=0.1).interrogate()
    outcome.assert_passing()


def test_pct_null_outside_tol_fail(half_null_ser: pl.Series) -> None:
    """Should fail if pct null is outside tolerance margin."""
    frame = pl.DataFrame({"a": half_null_ser})  # exactly 50% nulls
    outcome = Validate(frame).col_vals_pct_null(columns=["a"], p=0.4, tol=0.05).interrogate()
    with pytest.raises(AssertionError):
        outcome.assert_passing()


def test_pct_null_lower_bound_edge() -> None:
    """Passes when the rounded absolute target equals the observed null count.

    NOTE(review): despite the name, `tol` is 0.0 here — the pass comes from
    rounding: round(4 * 0.55) == 2, which equals the 2 observed nulls. Confirm
    whether a nonzero tol (e.g. 0.05) was intended.
    """
    data = pl.DataFrame({"a": [None, None, 1, 2]})  # 50% nulls
    # tol=0.0, so only rounding of the absolute target makes this pass
    validation = Validate(data).col_vals_pct_null(columns=["a"], p=0.55, tol=0.0).interrogate()
    validation.assert_passing()


def test_pct_null_upper_bound_edge() -> None:
    """Should pass exactly at upper bound of tolerance range.

    NOTE(review): round(4 * 0.2) == 1, which already equals the single observed
    null — so this would pass even with tol=0. Verify the tol=0.05 actually
    exercises the upper-bound edge as intended.
    """
    data = pl.DataFrame({"a": [None, 1, 2, 3]})  # 25% nulls
    # Expect 0.2 ± 0.05 => [0.15, 0.25]
    validation = Validate(data).col_vals_pct_null(columns=["a"], p=0.2, tol=0.05).interrogate()
    validation.assert_passing()


def test_pct_null_multiple_columns_with_tol() -> None:
    """Should check multiple columns with tolerance."""
    frame = pl.DataFrame(
        {
            "a": [None, None, 1, 2],  # 50% null
            "b": [1, None, 2, None],  # 50% null
            "c": [1, 2, 3, 4],  # no nulls
        }
    )
    outcome = (
        Validate(frame).col_vals_pct_null(columns=["a", "b", "c"], p=0.5, tol=0.01).interrogate()
    )
    # "c" has zero nulls, so the run as a whole cannot pass
    with pytest.raises(AssertionError):
        outcome.assert_passing()


def test_pct_null_low_tol(half_null_ser: pl.Series) -> None:
    """Tolerance is subject to rounding, and always relative to the total dataset."""
    small = pl.DataFrame({"a": [None, None, 2, 3]})  # 50% null
    # round(4 * 0.501) == 2 equals the observed nulls, so rounding lets this pass
    outcome = Validate(small).col_vals_pct_null(columns=["a"], p=0.501, tol=0.0).interrogate()
    outcome.assert_passing()

    # With 1000 rows there is no rounding slack: round(1000 * 0.501) == 501 != 500
    large = pl.DataFrame({"a": half_null_ser})
    outcome = Validate(large).col_vals_pct_null(columns=["a"], p=0.501, tol=0.0).interrogate()
    with pytest.raises(AssertionError):
        outcome.assert_passing()


def test_pct_null_high_tol_always_pass() -> None:
    """Large tolerance should allow big differences."""
    frame = pl.DataFrame({"a": [None, None, None, 1]})  # three of four null -> 75%
    outcome = Validate(frame).col_vals_pct_null(columns=["a"], p=0.25, tol=10).interrogate()
    outcome.assert_passing()

@pytest.fixture
def timezone_datetime_polars():
"""Polars DataFrame with timezone-aware datetime values."""
Expand Down
Loading