Skip to content

Commit 4fc60aa

Browse files
ref: ensure_integration_enabled without original function (#2893)
ensure_integration_enabled and ensure_integration_enabled_async can now decorate functions that return None without an original function. --------- Co-authored-by: Ivana Kellyerova <[email protected]>
1 parent fa17f3b commit 4fc60aa

File tree

3 files changed

+155
-7
lines changed

3 files changed

+155
-7
lines changed

sentry_sdk/utils.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@
3838
from typing import (
3939
Any,
4040
Callable,
41+
cast,
4142
ContextManager,
4243
Dict,
4344
Iterator,
4445
List,
4546
NoReturn,
4647
Optional,
48+
overload,
4749
ParamSpec,
4850
Set,
4951
Tuple,
@@ -1631,9 +1633,39 @@ def reraise(tp, value, tb=None):
16311633
raise value
16321634

16331635

1636+
def _no_op(*_a, **_k):
1637+
# type: (*Any, **Any) -> None
1638+
"""No-op function for ensure_integration_enabled."""
1639+
pass
1640+
1641+
1642+
async def _no_op_async(*_a, **_k):
1643+
# type: (*Any, **Any) -> None
1644+
"""No-op function for ensure_integration_enabled_async."""
1645+
pass
1646+
1647+
1648+
if TYPE_CHECKING:
1649+
1650+
@overload
1651+
def ensure_integration_enabled(
1652+
integration, # type: type[sentry_sdk.integrations.Integration]
1653+
original_function, # type: Callable[P, R]
1654+
):
1655+
# type: (...) -> Callable[[Callable[P, R]], Callable[P, R]]
1656+
...
1657+
1658+
@overload
1659+
def ensure_integration_enabled(
1660+
integration, # type: type[sentry_sdk.integrations.Integration]
1661+
):
1662+
# type: (...) -> Callable[[Callable[P, None]], Callable[P, None]]
1663+
...
1664+
1665+
16341666
def ensure_integration_enabled(
16351667
integration, # type: type[sentry_sdk.integrations.Integration]
1636-
original_function, # type: Callable[P, R]
1668+
original_function=_no_op, # type: Union[Callable[P, R], Callable[P, None]]
16371669
):
16381670
# type: (...) -> Callable[[Callable[P, R]], Callable[P, R]]
16391671
"""
@@ -1657,25 +1689,51 @@ def patch_my_function():
16571689
return my_function()
16581690
```
16591691
"""
1692+
if TYPE_CHECKING:
1693+
# Type hint to ensure the default function has the right typing. The overloads
1694+
# ensure the default _no_op function is only used when R is None.
1695+
original_function = cast(Callable[P, R], original_function)
16601696

16611697
def patcher(sentry_patched_function):
16621698
# type: (Callable[P, R]) -> Callable[P, R]
1663-
@wraps(original_function)
16641699
def runner(*args: "P.args", **kwargs: "P.kwargs"):
16651700
# type: (...) -> R
16661701
if sentry_sdk.get_client().get_integration(integration) is None:
16671702
return original_function(*args, **kwargs)
16681703

16691704
return sentry_patched_function(*args, **kwargs)
16701705

1671-
return runner
1706+
if original_function is _no_op:
1707+
return wraps(sentry_patched_function)(runner)
1708+
1709+
return wraps(original_function)(runner)
16721710

16731711
return patcher
16741712

16751713

1676-
def ensure_integration_enabled_async(
1714+
if TYPE_CHECKING:
1715+
1716+
# mypy has some trouble with the overloads, hence the ignore[no-overload-impl]
1717+
@overload # type: ignore[no-overload-impl]
1718+
def ensure_integration_enabled_async(
1719+
integration, # type: type[sentry_sdk.integrations.Integration]
1720+
original_function, # type: Callable[P, Awaitable[R]]
1721+
):
1722+
# type: (...) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]
1723+
...
1724+
1725+
@overload
1726+
def ensure_integration_enabled_async(
1727+
integration, # type: type[sentry_sdk.integrations.Integration]
1728+
):
1729+
# type: (...) -> Callable[[Callable[P, Awaitable[None]]], Callable[P, Awaitable[None]]]
1730+
...
1731+
1732+
1733+
# The ignore[no-redef] also needed because mypy is struggling with these overloads.
1734+
def ensure_integration_enabled_async( # type: ignore[no-redef]
16771735
integration, # type: type[sentry_sdk.integrations.Integration]
1678-
original_function, # type: Callable[P, Awaitable[R]]
1736+
original_function=_no_op_async, # type: Union[Callable[P, Awaitable[R]], Callable[P, Awaitable[None]]]
16791737
):
16801738
# type: (...) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]
16811739
"""
@@ -1684,17 +1742,24 @@ def ensure_integration_enabled_async(
16841742
Please refer to the `ensure_integration_enabled` documentation for more information.
16851743
"""
16861744

1745+
if TYPE_CHECKING:
1746+
# Type hint to ensure the default function has the right typing. The overloads
1747+
# ensure the default _no_op function is only used when R is None.
1748+
original_function = cast(Callable[P, Awaitable[R]], original_function)
1749+
16871750
def patcher(sentry_patched_function):
16881751
# type: (Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]
1689-
@wraps(original_function)
16901752
async def runner(*args: "P.args", **kwargs: "P.kwargs"):
16911753
# type: (...) -> R
16921754
if sentry_sdk.get_client().get_integration(integration) is None:
16931755
return await original_function(*args, **kwargs)
16941756

16951757
return await sentry_patched_function(*args, **kwargs)
16961758

1697-
return runner
1759+
if original_function is _no_op_async:
1760+
return wraps(sentry_patched_function)(runner)
1761+
1762+
return wraps(original_function)(runner)
16981763

16991764
return patcher
17001765

tests/test_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ def function_to_patch():
594594
)
595595

596596
assert patched_function() == "patched"
597+
assert patched_function.__name__ == "original_function"
597598

598599

599600
def test_ensure_integration_enabled_integration_disabled(sentry_init):
@@ -611,6 +612,41 @@ def function_to_patch():
611612
)
612613

613614
assert patched_function() == "original"
615+
assert patched_function.__name__ == "original_function"
616+
617+
618+
def test_ensure_integration_enabled_no_original_function_enabled(sentry_init):
619+
shared_variable = "original"
620+
621+
def function_to_patch():
622+
nonlocal shared_variable
623+
shared_variable = "patched"
624+
625+
sentry_init(integrations=[TestIntegration])
626+
627+
# Test the decorator by applying to function_to_patch
628+
patched_function = ensure_integration_enabled(TestIntegration)(function_to_patch)
629+
patched_function()
630+
631+
assert shared_variable == "patched"
632+
assert patched_function.__name__ == "function_to_patch"
633+
634+
635+
def test_ensure_integration_enabled_no_original_function_disabled(sentry_init):
636+
shared_variable = "original"
637+
638+
def function_to_patch():
639+
nonlocal shared_variable
640+
shared_variable = "patched"
641+
642+
sentry_init(integrations=[])
643+
644+
# Test the decorator by applying to function_to_patch
645+
patched_function = ensure_integration_enabled(TestIntegration)(function_to_patch)
646+
patched_function()
647+
648+
assert shared_variable == "original"
649+
assert patched_function.__name__ == "function_to_patch"
614650

615651

616652
@pytest.mark.asyncio
@@ -630,6 +666,7 @@ async def function_to_patch():
630666
)(function_to_patch)
631667

632668
assert await patched_function() == "patched"
669+
assert patched_function.__name__ == "original_function"
633670

634671

635672
@pytest.mark.asyncio
@@ -649,3 +686,48 @@ async def function_to_patch():
649686
)(function_to_patch)
650687

651688
assert await patched_function() == "original"
689+
assert patched_function.__name__ == "original_function"
690+
691+
692+
@pytest.mark.asyncio
693+
async def test_ensure_integration_enabled_async_no_original_function_enabled(
694+
sentry_init,
695+
):
696+
shared_variable = "original"
697+
698+
async def function_to_patch():
699+
nonlocal shared_variable
700+
shared_variable = "patched"
701+
702+
sentry_init(integrations=[TestIntegration])
703+
704+
# Test the decorator by applying to function_to_patch
705+
patched_function = ensure_integration_enabled_async(TestIntegration)(
706+
function_to_patch
707+
)
708+
await patched_function()
709+
710+
assert shared_variable == "patched"
711+
assert patched_function.__name__ == "function_to_patch"
712+
713+
714+
@pytest.mark.asyncio
715+
async def test_ensure_integration_enabled_async_no_original_function_disabled(
716+
sentry_init,
717+
):
718+
shared_variable = "original"
719+
720+
async def function_to_patch():
721+
nonlocal shared_variable
722+
shared_variable = "patched"
723+
724+
sentry_init(integrations=[])
725+
726+
# Test the decorator by applying to function_to_patch
727+
patched_function = ensure_integration_enabled_async(TestIntegration)(
728+
function_to_patch
729+
)
730+
await patched_function()
731+
732+
assert shared_variable == "original"
733+
assert patched_function.__name__ == "function_to_patch"

tests/tracing/test_decorator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ async def my_async_example_function():
1515
return "return_of_async_function"
1616

1717

18+
@pytest.mark.forked
1819
def test_trace_decorator():
1920
with patch_start_tracing_child() as fake_start_child:
2021
result = my_example_function()

0 commit comments

Comments
 (0)