Skip to content

Commit 681acd4

Browse files
authored
Merge pull request #121 from AllenNeuralDynamics:feat-run-if
Add `run_if` decorator to skip execution given a predicate
2 parents 362f33a + 6ce1abd commit 681acd4

File tree

3 files changed

+75
-2
lines changed

3 files changed

+75
-2
lines changed

src/clabe/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.6.8"
1+
__version__ = "0.6.9"
22

33
import logging
44

src/clabe/launcher/_callable_manager.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,30 @@ def wrapper(*args, **kwargs):
194194
return wrapper
195195

196196
return decorator
197+
198+
199+
def run_if(predicate: t.Callable[..., bool]) -> t.Callable[[t.Callable[P, R]], t.Callable[P, Optional[R]]]:
200+
"""
201+
A decorator that only runs the wrapped function if the predicate returns True.
202+
If the predicate returns False, returns None.
203+
204+
Args:
205+
predicate: A callable that takes the same arguments as the wrapped function and returns a boolean.
206+
207+
Returns:
208+
The decorated function that runs only if predicate(*args, **kwargs) is True, else returns None.
209+
"""
210+
211+
def decorator(func: t.Callable[P, R]) -> t.Callable[P, Optional[R]]:
212+
@functools.wraps(func)
213+
def wrapper(*args, **kwargs):
214+
fn_name = getattr(func, "__name__", repr(func))
215+
if predicate(*args, **kwargs):
216+
logger.debug(f"Predicate passed for {fn_name}, executing function")
217+
return func(*args, **kwargs)
218+
logger.debug(f"Predicate failed for {fn_name}, skipping execution")
219+
return None
220+
221+
return wrapper
222+
223+
return decorator

tests/launcher/test_callable_manager.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from clabe.launcher._callable_manager import Promise, _CallableManager, _UnsetType, ignore_errors
6+
from clabe.launcher._callable_manager import Promise, _CallableManager, _UnsetType, ignore_errors, run_if
77

88

99
class TestCallableManager:
@@ -278,3 +278,49 @@ def test_ignore_errors_lambda_function(self, caplog):
278278
assert result2 == "lambda_failed"
279279
# Lambda functions have a generic name
280280
assert "Exception in <lambda>: division by zero" in caplog.text
281+
282+
283+
class TestRunIfDecorator:
284+
def test_run_if_runs_when_predicate_true(self):
285+
def always_true(*args, **kwargs):
286+
return True
287+
288+
@run_if(always_true)
289+
def my_func(x):
290+
return x * 2
291+
292+
assert my_func(3) == 6
293+
294+
def test_run_if_returns_none_when_predicate_false(self):
295+
def always_false(*args, **kwargs):
296+
return False
297+
298+
@run_if(always_false)
299+
def my_func(x):
300+
return x * 2
301+
302+
assert my_func(3) is None
303+
304+
def test_run_if_predicate_depends_on_args(self):
305+
def is_positive(x):
306+
return x > 0
307+
308+
@run_if(is_positive)
309+
def square(x):
310+
return x * x
311+
312+
assert square(2) == 4
313+
assert square(-1) is None
314+
315+
def test_run_if_preserves_function_metadata(self):
316+
def always_true(*args, **kwargs):
317+
return True
318+
319+
@run_if(always_true)
320+
def documented_func(x):
321+
"""This function squares its input."""
322+
return x * x
323+
324+
assert hasattr(documented_func, "__name__")
325+
assert hasattr(documented_func, "__doc__")
326+
assert documented_func.__doc__ == "This function squares its input."

0 commit comments

Comments
 (0)