Skip to content

Commit 55b9ce4

Browse files
committed
Create helper class decorator
1 parent c99fbad commit 55b9ce4

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

src/common/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
def apply_decorator_to_methods(
2+
decorator, protected_methods: bool = False, private_methods: bool = False
3+
):
4+
"""
5+
Class decorator to apply a given function or coroutine decorator
6+
to all functions and coroutines within a class.
7+
"""
8+
9+
def class_decorator(cls):
10+
for attr_name, attr_value in cls.__dict__.items():
11+
# Check if the attribute is a callable (method or coroutine)
12+
if not callable(attr_value):
13+
continue
14+
15+
if attr_name.startswith(f"_{cls.__name__}__"):
16+
if not private_methods:
17+
continue
18+
19+
elif attr_name.startswith("_"):
20+
if not protected_methods:
21+
continue
22+
23+
# Replace the original callable with the decorated version
24+
setattr(cls, attr_name, decorator(attr_value))
25+
return cls
26+
27+
return class_decorator

tests/common/__init__.py

Whitespace-only changes.

tests/common/test_utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from common.utils import apply_decorator_to_methods
6+
7+
8+
@pytest.mark.parametrize(
9+
"apply_to_protected_methods",
10+
[
11+
True,
12+
False,
13+
],
14+
)
15+
@pytest.mark.parametrize(
16+
"apply_to_private_methods",
17+
[
18+
True,
19+
False,
20+
],
21+
)
22+
async def test_class_decorator(
23+
apply_to_protected_methods: bool,
24+
apply_to_private_methods: bool,
25+
):
26+
def add_ten_decorator(func):
27+
def wrapper(*args, **kwargs):
28+
result = func(*args, **kwargs)
29+
return result + 10
30+
31+
async def async_wrapper(*args, **kwargs):
32+
result = await func(*args, **kwargs)
33+
return result + 10
34+
35+
return wrapper if not asyncio.iscoroutinefunction(func) else async_wrapper
36+
37+
@apply_decorator_to_methods(
38+
decorator=add_ten_decorator,
39+
protected_methods=apply_to_protected_methods,
40+
private_methods=apply_to_private_methods,
41+
)
42+
class MyClass:
43+
def get_public(self):
44+
return 10
45+
46+
def _get_protected(self):
47+
return 10
48+
49+
def __get_private(self):
50+
return 10
51+
52+
async def get_apublic(self):
53+
return 10
54+
55+
async def _get_aprotected(self):
56+
return 10
57+
58+
async def __get_aprivate(self):
59+
return 10
60+
61+
c = MyClass()
62+
assert c.get_public() == 20
63+
assert c._get_protected() == 20 if apply_to_protected_methods else 10
64+
assert c._MyClass__get_private() == 20 if apply_to_private_methods else 10
65+
assert await c.get_apublic() == 20
66+
assert await c._get_aprotected() == 20 if apply_to_protected_methods else 10
67+
assert await c._MyClass__get_aprivate() == 20 if apply_to_private_methods else 10

0 commit comments

Comments
 (0)