Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pointblank/thresholds.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ class Actions:
critical
A string, `Callable`, or list of `Callable`/string values for the 'critical' level. Using
`None` means no action should be performed at the 'critical' level.
highest_only
A boolean value that, when set to `True` (the default), results in executing only the action
for the highest threshold level that is exceeded. Useful when you want to ensure that only
the most severe action is taken when multiple threshold levels are exceeded.

Returns
-------
Expand Down Expand Up @@ -438,6 +442,7 @@ def dq_issue():
warning: str | Callable | list[str | Callable] | None = None
error: str | Callable | list[str | Callable] | None = None
critical: str | Callable | list[str | Callable] | None = None
highest_only: bool = True

def __post_init__(self):
self.warning = self._ensure_list(self.warning)
Expand Down
16 changes: 13 additions & 3 deletions pointblank/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5394,9 +5394,9 @@ def interrogate(
if collect_tbl_checked and results_tbl is not None:
validation.tbl_checked = results_tbl

# Perform any necessary actions if threshold levels are exceeded for each
# of the severity levels ('warning', 'error', 'critical')
for level in ["warning", "error", "critical"]:
# Perform any necessary actions if threshold levels are exceeded for each of
# the severity levels (in descending order of 'critical', 'error', and 'warning')
for level in ["critical", "error", "warning"]:
if getattr(validation, level) and (
self.actions is not None or validation.actions is not None
):
Expand Down Expand Up @@ -5449,6 +5449,9 @@ def interrogate(
with _action_context_manager(metadata):
act()

if validation.actions.highest_only:
break

elif self.actions is not None:
# Action execution on the global level
action = self.actions._get_action(level=level)
Expand Down Expand Up @@ -5489,6 +5492,9 @@ def interrogate(
with _action_context_manager(metadata):
act()

if self.actions.highest_only:
break

# If this is a row-based validation step, then extract the rows that failed
# TODO: Add support for extraction of rows for Ibis backends
if (
Expand Down Expand Up @@ -7966,6 +7972,10 @@ def _process_action_str(
# If a `col` value is available for the validation step *and* the action string contains a
# placeholder for the column name then replace with `col`; placeholders are: {col} and {column}
if col is not None:
# If a list of columns is provided, then join the columns into a comma-separated string
if isinstance(col, list):
col = ", ".join(col)

action_str = action_str.replace("{col}", col)
action_str = action_str.replace("{column}", col)

Expand Down
98 changes: 98 additions & 0 deletions tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,104 @@ def test_validation_actions_step_only_none(request, tbl_fixture, capsys):
assert captured.out == ""


@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"])
def test_validation_actions_global_highest(tbl_type, capsys):
    """Global actions with `highest_only=True` emit only the 'critical' message.

    The plan-level `Actions` object carries a distinct string for each severity
    level; after interrogation only the 'critical' string is expected on stdout.
    """
    small_table = load_dataset(dataset="small_table", tbl_type=tbl_type)
    levels = Thresholds(warning=1, error=2, critical=3)
    global_actions = Actions(
        warning="W_global", error="E_global", critical="C_global", highest_only=True
    )

    plan = Validate(data=small_table, thresholds=levels, actions=global_actions)
    plan.col_vals_gt(columns="d", value=10000).interrogate()

    # Only the message for the most severe level should have been printed
    printed = capsys.readouterr().out
    assert "C_global" in printed
    assert "E_global" not in printed
    assert "W_global" not in printed


@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"])
def test_validation_actions_global_all(tbl_type, capsys):
    """Global actions with `highest_only=False` emit a message for every level.

    The plan-level `Actions` object carries a distinct string for each severity
    level; after interrogation all three strings are expected on stdout.
    """
    small_table = load_dataset(dataset="small_table", tbl_type=tbl_type)
    levels = Thresholds(warning=1, error=2, critical=3)
    global_actions = Actions(
        warning="W_global", error="E_global", critical="C_global", highest_only=False
    )

    plan = Validate(data=small_table, thresholds=levels, actions=global_actions)
    plan.col_vals_gt(columns="d", value=10000).interrogate()

    # Each of the three level messages should have been printed
    printed = capsys.readouterr().out
    assert "C_global" in printed
    assert "E_global" in printed
    assert "W_global" in printed


@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"])
def test_validation_actions_local_highest(tbl_type, capsys):
    """Step-level actions with `highest_only=True` emit only the 'critical' message.

    The plan-level actions use `highest_only=False` while the step-level actions
    use `highest_only=True`; only the step-level 'critical' string is expected
    on stdout.
    """
    small_table = load_dataset(dataset="small_table", tbl_type=tbl_type)
    levels = Thresholds(warning=1, error=2, critical=3)
    global_actions = Actions(
        warning="W_global", error="E_global", critical="C_global", highest_only=False
    )
    plan = Validate(data=small_table, thresholds=levels, actions=global_actions)

    step_actions = Actions(
        warning="W_local", error="E_local", critical="C_local", highest_only=True
    )
    plan.col_vals_gt(columns="d", value=10000, actions=step_actions).interrogate()

    # Only the step-level message for the most severe level should have been printed
    printed = capsys.readouterr().out
    assert "C_local" in printed
    assert "E_local" not in printed
    assert "W_local" not in printed


@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"])
def test_validation_actions_local_all(tbl_type, capsys):
    """Step-level actions with `highest_only=False` emit a message for every level.

    The plan-level actions use `highest_only=True` while the step-level actions
    use `highest_only=False`; all three step-level strings are expected on
    stdout.
    """
    small_table = load_dataset(dataset="small_table", tbl_type=tbl_type)
    levels = Thresholds(warning=1, error=2, critical=3)
    global_actions = Actions(
        warning="W_global", error="E_global", critical="C_global", highest_only=True
    )
    plan = Validate(data=small_table, thresholds=levels, actions=global_actions)

    step_actions = Actions(
        warning="W_local", error="E_local", critical="C_local", highest_only=False
    )
    plan.col_vals_gt(columns="d", value=10000, actions=step_actions).interrogate()

    # Each of the three step-level messages should have been printed
    printed = capsys.readouterr().out
    assert "C_local" in printed
    assert "E_local" in printed
    assert "W_local" in printed


@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"])
def test_validation_actions_get_action_metadata(tbl_type, capsys):
def log_issue():
Expand Down