diff --git a/pointblank/thresholds.py b/pointblank/thresholds.py index 2c649c3ac..ab604a50f 100644 --- a/pointblank/thresholds.py +++ b/pointblank/thresholds.py @@ -317,6 +317,10 @@ class Actions: critical A string, `Callable`, or list of `Callable`/string values for the 'critical' level. Using `None` means no action should be performed at the 'critical' level. + highest_only + A boolean value that, when set to `True` (the default), results in executing only the action + for the highest threshold level that is exceeded. Useful when you want to ensure that only + the most severe action is taken when multiple threshold levels are exceeded. Returns ------- @@ -438,6 +442,7 @@ def dq_issue(): warning: str | Callable | list[str | Callable] | None = None error: str | Callable | list[str | Callable] | None = None critical: str | Callable | list[str | Callable] | None = None + highest_only: bool = True def __post_init__(self): self.warning = self._ensure_list(self.warning) diff --git a/pointblank/validate.py b/pointblank/validate.py index a451d4877..ce0d9aabb 100644 --- a/pointblank/validate.py +++ b/pointblank/validate.py @@ -5394,9 +5394,9 @@ def interrogate( if collect_tbl_checked and results_tbl is not None: validation.tbl_checked = results_tbl - # Perform any necessary actions if threshold levels are exceeded for each - # of the severity levels ('warning', 'error', 'critical') - for level in ["warning", "error", "critical"]: + # Perform any necessary actions if threshold levels are exceeded for each of + # the severity levels (in descending order of 'critical', 'error', and 'warning') + for level in ["critical", "error", "warning"]: if getattr(validation, level) and ( self.actions is not None or validation.actions is not None ): @@ -5449,6 +5449,9 @@ def interrogate( with _action_context_manager(metadata): act() + if validation.actions.highest_only: + break + elif self.actions is not None: # Action execution on the global level action = 
self.actions._get_action(level=level) @@ -5489,6 +5492,9 @@ def interrogate( with _action_context_manager(metadata): act() + if self.actions.highest_only: + break + # If this is a row-based validation step, then extract the rows that failed # TODO: Add support for extraction of rows for Ibis backends if ( @@ -7966,6 +7972,10 @@ def _process_action_str( # If a `col` value is available for the validation step *and* the action string contains a # placeholder for the column name then replace with `col`; placeholders are: {col} and {column} if col is not None: + # If a list of columns is provided, then join the columns into a comma-separated string + if isinstance(col, list): + col = ", ".join(col) + action_str = action_str.replace("{col}", col) action_str = action_str.replace("{column}", col) diff --git a/tests/test_validate.py b/tests/test_validate.py index 93dae3c3f..c910e8275 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1571,6 +1571,104 @@ def test_validation_actions_step_only_none(request, tbl_fixture, capsys): assert captured.out == "" +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_global_highest(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions( + warning="W_global", error="E_global", critical="C_global", highest_only=True + ), + ) + .col_vals_gt(columns="d", value=10000) + .interrogate() + ) + + # Capture the output and verify that only the highest priority level + # message is printed to the console + captured = capsys.readouterr() + assert "C_global" in captured.out + assert "E_global" not in captured.out + assert "W_global" not in captured.out + + +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_global_all(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + 
thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions( + warning="W_global", error="E_global", critical="C_global", highest_only=False + ), + ) + .col_vals_gt(columns="d", value=10000) + .interrogate() + ) + + # Capture the output and verify that all three level messages are printed to the console + captured = capsys.readouterr() + assert "C_global" in captured.out + assert "E_global" in captured.out + assert "W_global" in captured.out + + +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_local_highest(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions( + warning="W_global", error="E_global", critical="C_global", highest_only=False + ), + ) + .col_vals_gt( + columns="d", + value=10000, + actions=Actions( + warning="W_local", error="E_local", critical="C_local", highest_only=True + ), + ) + .interrogate() + ) + + # Capture the output and verify that only the highest priority level + # message is printed to the console + captured = capsys.readouterr() + assert "C_local" in captured.out + assert "E_local" not in captured.out + assert "W_local" not in captured.out + + +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_local_all(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions( + warning="W_global", error="E_global", critical="C_global", highest_only=True + ), + ) + .col_vals_gt( + columns="d", + value=10000, + actions=Actions( + warning="W_local", error="E_local", critical="C_local", highest_only=False + ), + ) + .interrogate() + ) + + # Capture the output and verify that all three level messages are printed to the console + captured = capsys.readouterr() + assert "C_local" in captured.out + assert "E_local" in 
captured.out + assert "W_local" in captured.out + + @pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) def test_validation_actions_get_action_metadata(tbl_type, capsys): def log_issue():