diff --git a/pointblank/thresholds.py b/pointblank/thresholds.py index ab604a50f..e941a01e6 100644 --- a/pointblank/thresholds.py +++ b/pointblank/thresholds.py @@ -304,7 +304,9 @@ class Actions: to different levels of severity when a threshold is reached. Those thresholds can be defined using the [`Thresholds`](`pointblank.Thresholds`) class or various shorthand forms. Actions don't have to be defined for all threshold levels; if an action is not defined for a level in - exceedance, no action will be taken. + exceedance, no action will be taken. Likewise, there is no negative consequence (other than a + no-op) for defining actions for thresholds that don't exist (e.g., setting an action for the + 'critical' level when no corresponding 'critical' threshold has been set). Parameters ---------- @@ -317,6 +319,10 @@ class Actions: critical A string, `Callable`, or list of `Callable`/string values for the 'critical' level. Using `None` means no action should be performed at the 'critical' level. + default + A string, `Callable`, or list of `Callable`/string values for all threshold levels. This + parameter can be used to set the same action for all threshold levels. If an action is + defined for a specific threshold level, it will override the action set for all levels. highest_only A boolean value that, when set to `True` (the default), results in executing only the action for the highest threshold level that is exceeded. Useful when you want to ensure that only @@ -442,6 +448,7 @@ def dq_issue(): warning: str | Callable | list[str | Callable] | None = None error: str | Callable | list[str | Callable] | None = None critical: str | Callable | list[str | Callable] | None = None + default: str | Callable | list[str | Callable] | None = None highest_only: bool = True def __post_init__(self): @@ -449,6 +456,17 @@ def __post_init__(self): self.error = self._ensure_list(self.error) self.critical = self._ensure_list(self.critical) + if self.default is not None: + self.default = self._ensure_list(self.default) + + # For any unset threshold level, set the default action + if self.warning is None: + self.warning = self.default + if self.error is None: + self.error = self.default + if self.critical is None: + self.critical = self.default + def _ensure_list( self, value: str | Callable | list[str | Callable] | None ) -> list[str | Callable]: diff --git a/tests/test_validate.py b/tests/test_validate.py index c910e8275..326cf49d7 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1669,6 +1669,100 @@ def test_validation_actions_local_all(tbl_type, capsys): assert "W_local" in captured.out +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_default_global(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions(default="{level} default_action", highest_only=False), + ) + .col_vals_gt(columns="d", value=10000) + .interrogate() + ) + + # Capture the output and verify that all three level messages are printed to the console + captured = capsys.readouterr() + assert "critical default_action" in captured.out + assert "error default_action" in captured.out + assert "warning default_action" in captured.out + + +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_default_global_override(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions( + warning="warning override", default="{level} default_action", highest_only=False + ), + ) + .col_vals_gt(columns="d", value=10000) + .interrogate() + ) + + # Capture the output and verify that all three level messages are printed to the console + captured = capsys.readouterr() + assert "critical default_action" in captured.out + assert "error default_action" in captured.out + assert "warning override" in captured.out + + +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_default_local(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions(default="{level} default_action_global", highest_only=False), + ) + .col_vals_gt( + columns="d", + value=10000, + actions=Actions(default="{level} default_action_local", highest_only=False), + ) + .interrogate() + ) + + # Capture the output and verify that all three level messages are printed to the console + captured = capsys.readouterr() + assert "critical default_action_local" in captured.out + assert "error default_action_local" in captured.out + assert "warning default_action_local" in captured.out + + +@pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) +def test_validation_actions_default_local_override(tbl_type, capsys): + ( + Validate( + data=load_dataset(dataset="small_table", tbl_type=tbl_type), + thresholds=Thresholds(warning=1, error=2, critical=3), + actions=Actions( + warning="warning override_global", + default="{level} default_action_global", + highest_only=False, + ), + ) + .col_vals_gt( + columns="d", + value=10000, + actions=Actions( + warning="warning override_local", + default="{level} default_action_local", + highest_only=False, + ), + ) + .interrogate() + ) + + # Capture the output and verify that all three level messages are printed to the console + captured = capsys.readouterr() + assert "critical default_action_local" in captured.out + assert "error default_action_local" in captured.out + assert "warning override_local" in captured.out + + @pytest.mark.parametrize("tbl_type", ["pandas", "polars", "duckdb"]) def test_validation_actions_get_action_metadata(tbl_type, capsys): def log_issue():