Skip to content

Commit 593b634

Browse files
authored
Merge pull request #662 from aai-institute/feature/better-warnings-filter
Allow suppress_warnings to raise exceptions / any other valid action
2 parents af6752e + c65ff32 commit 593b634

File tree

3 files changed

+84
-43
lines changed

3 files changed

+84
-43
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@
5858

5959
### Fixed
6060

61-
- Fixed `show_warnings=False` not being respected in subprocesses
61+
- Fixed `show_warnings=False` not being respected in subprocesses. Introduced
62+
`suppress_warninigs` decorator for more flexibility
6263
[PR #647](https://github.com/aai-institute/pyDVL/pull/647)
64+
[PR #662](https://github.com/aai-institute/pyDVL/pull/662)
6365
- Fixed several bugs in diverse stopping criteria, including: iteration counts,
6466
computing completion, resetting, nested composition
6567
[PR #641](https://github.com/aai-institute/pyDVL/pull/641)

src/pydvl/utils/functional.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -163,44 +163,63 @@ def suppress_warnings(
163163
) -> Union[Callable[[Callable[P, R]], Callable[P, R]], Callable[P, R]]:
164164
"""Decorator for class methods to conditionally suppress warnings.
165165
166-
The decorated method will execute with warnings suppressed for the specified
167-
categories. If the instance has the attribute named by `flag`, and it evaluates to
168-
`True`, then suppression will be deactivated.
169-
170-
??? Example "Suppress all warnings"
171-
```python
172-
class A:
173-
@suppress_warnings
174-
def method(self, ...):
175-
...
176-
```
177-
??? Example "Suppress only `UserWarning`"
178-
```python
179-
class A:
180-
@suppress_warnings(categories=(UserWarning,))
181-
def method(self, ...):
182-
...
183-
```
184-
??? Example "Configuring behaviour at runtime"
185-
```python
186-
class A:
187-
def __init__(self, warn_enabled: bool):
188-
self.warn_enabled = warn_enabled
189-
190-
@suppress_warnings(flag="warn_enabled")
191-
def method(self, ...):
192-
...
193-
```
194-
195-
Args:
196-
fun: Optional callable to decorate. If provided, the decorator is applied inline.
197-
categories: Sequence of warning categories to suppress.
198-
flag: Name of an instance attribute to check for enabling warnings. If the
199-
attribute exists and evaluates to `True`, warnings will **not** be
200-
suppressed.
201-
202-
Returns:
203-
Either a decorator (if no function is provided) or the decorated callable.
166+
The decorated method will execute with warnings suppressed for the specified
167+
categories. If the instance has the attribute named by `flag`, and it's a boolean
168+
evaluating to `False`, warnings will be ignored. If the attribute is a string, then
169+
it is interpreted as an "action" to be performed on the categories specified.
170+
Allowed values are as per [warnings.simplefilter][], which are:
171+
`default`, `error`, `ignore`, `always`, `all`, `module`, `once`
172+
173+
??? Example "Suppress all warnings"
174+
```python
175+
class A:
176+
@suppress_warnings
177+
def method(self, ...):
178+
...
179+
```
180+
??? Example "Suppress only `UserWarning`"
181+
```python
182+
class A:
183+
@suppress_warnings(categories=(UserWarning,))
184+
def method(self, ...):
185+
...
186+
```
187+
??? Example "Configuring behaviour at runtime"
188+
```python
189+
class A:
190+
def __init__(self, warn_enabled: bool):
191+
self.warn_enabled = warn_enabled
192+
193+
@suppress_warnings(flag="warn_enabled")
194+
def method(self, ...):
195+
...
196+
```
197+
198+
??? Example "Raising on RuntimeWarning"
199+
```python
200+
class A:
201+
def __init__(self, warnings: str = "error"):
202+
self.warnings = warnings
203+
204+
@suppress_warnings(flag="warnings")
205+
def method(self, ...):
206+
...
207+
208+
A().method() # Raises RuntimeWarning
209+
```
210+
211+
212+
Args:
213+
fun: Optional callable to decorate. If provided, the decorator is applied inline.
214+
categories: Sequence of warning categories to suppress.
215+
flag: Name of an instance attribute to check for enabling warnings. If the
216+
attribute exists and evaluates to `False`, warnings will be ignored. If
217+
it evaluates to a str, then this action will be performed on the categories
218+
specified. Allowed values are as per [warnings.simplefilter][], which are:
219+
`default`, `error`, `ignore`, `always`, `all`, `module`, `once`
220+
221+
Returns:
222+
Either a decorator (if no function is provided) or the decorated callable.
204223
"""
205224

206225
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
@@ -227,11 +246,18 @@ def wrapper(self, *args: Any, **kwargs: Any) -> R:
227246
raise AttributeError(
228247
f"Instance has no attribute '{flag}' for suppress_warnings"
229248
)
230-
if flag and getattr(self, flag, False):
249+
if flag and getattr(self, flag, False) is True:
231250
return fn(self, *args, **kwargs)
251+
# flag is either False or a string
232252
with warnings.catch_warnings():
253+
if (action := getattr(self, flag, "ignore")) is False:
254+
action = "ignore"
255+
elif not isinstance(action, str):
256+
raise TypeError(
257+
f"Expected a boolean or string for flag '{flag}', got {type(action).__name__}"
258+
)
233259
for category in categories:
234-
warnings.simplefilter("ignore", category=category)
260+
warnings.simplefilter(action, category=category) # type: ignore
235261
return fn(self, *args, **kwargs)
236262

237263
return cast(Callable[P, R], wrapper)

tests/utils/test_functional.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import gc
1+
from __future__ import annotations
2+
23
import inspect
34
import logging
45
import time
@@ -13,7 +14,7 @@
1314

1415

1516
class WarningsClass:
16-
def __init__(self, show_warnings: bool = True):
17+
def __init__(self, show_warnings: str | bool = True):
1718
self.show_warnings = show_warnings
1819

1920
@suppress_warnings(categories=(UserWarning,), flag="show_warnings")
@@ -99,6 +100,18 @@ def test_different_categories(
99100
assert warning_message in str(w.message)
100101

101102

103+
def test_raises_on_flag_error():
104+
obj = WarningsClass(show_warnings="error")
105+
with pytest.raises(UserWarning):
106+
obj.method_warn()
107+
108+
109+
def test_invalid_flag_type():
110+
obj = WarningsClass(show_warnings=42)
111+
with pytest.raises(TypeError):
112+
obj.method_warn()
113+
114+
102115
def test_nonmethod_decorator_usage():
103116
@suppress_warnings(categories=(RuntimeWarning,))
104117
def fun(x: int) -> float:

0 commit comments

Comments
 (0)