Skip to content

Commit 707e7e8

Browse files
authored
Merge pull request #1 from hyptocrypto/multi-return-decorator
Multi return decorator
2 parents 2e5f1be + 2667201 commit 707e7e8

File tree

5 files changed

+107
-28
lines changed

5 files changed

+107
-28
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ EnforcedTypeError: Can not assign type: <class 'int'> to attribute 'name'. Must
2222
EnforcedTypeError is thrown because the value of the name attribute on a Person object must always be a string.
2323

2424
### Decorators
25-
There are also some handy decorators to ensure that a method arguments are the right type, or that the method returns the proper type. Using the <b>check_return_type</b> decorator will throw an error if the method returns an invalid type.
25+
There are also some handy decorators to ensure that a method arguments are the right type, or that the method returns the proper type. Using the <b>check_return_types</b> decorator will throw an error if the method returns an invalid type.
2626
```python
27-
from pystrong import check_return_type
27+
from pystrong import check_return_types
2828

29-
@check_return_type(str, int)
29+
@check_return_types(str, int)
3030
def test_success():
3131
return "Hello World", 500
3232

33-
@check_return_type(str, int)
33+
@check_return_types(str, int)
3434
def test_fail():
3535
return 500, {"test": "error"}
3636

src/pystrong/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from .decorators import check_return_type, check_arg_type
1+
from .decorators import check_arg_type, check_return_types
22
from .enforcer import InferredTypeEnforcer, TypeEnforcer
33

4-
__all__ = ["TypeEnforcer", "InferredTypeEnforcer", "check_return_type", "check_arg_type"]
4+
__all__ = [
5+
"TypeEnforcer",
6+
"InferredTypeEnforcer",
7+
"check_return_types",
8+
"check_arg_type",
9+
]

src/pystrong/decorators.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,41 @@
1-
from .exceptions import EnforcedReturnTypeError, EnforcedArgTypeError
1+
from .exceptions import EnforcedArgTypeError, EnforcedReturnTypeError
22

3+
non_iterables = [int, float, bool]
34

4-
def check_return_type(*ret_types):
5+
6+
def check_return_types(*ret_types):
7+
# Make sure all args passed to decorator are a valid python type
8+
if not ret_types:
9+
raise EnforcedReturnTypeError(
10+
"Please pass at least one type to check_return_types"
11+
)
512
for _type in ret_types:
613
if type(_type) is not type:
714
raise EnforcedReturnTypeError(
8-
f"Check return type arguments must be a valid python type. Please check '{_type}'"
15+
"Check return type arguments must be a valid python type. Please check"
16+
f" '{_type}'"
917
)
1018

1119
def wrapper(func):
1220
def inner(*args, **kwargs):
13-
retruns = func(*args, **kwargs)
14-
for ret, ret_type in zip(retruns, ret_types):
15-
if type(ret) != ret_type:
21+
# The return values form the function that is being decorated
22+
rets = func(*args, **kwargs)
23+
func_return_types = parse_return_types(rets)
24+
25+
if len(func_return_types) != len(ret_types):
26+
raise EnforcedReturnTypeError(
27+
"Length of type argument passed to check_return_type is not equal"
28+
f" to length of values returned by function '{func.__name__}'"
29+
)
30+
31+
# Iter over actual returns and expected returns and check for validity
32+
for ret, ret_type in zip(func_return_types, ret_types):
33+
if ret != ret_type:
1634
raise EnforcedReturnTypeError(
17-
f"Function '{func.__name__}' returned '{ret}' of type '{type(ret)}'. Expected return type is '{ret_type}'."
35+
f"Function '{func.__name__}' returned '{func_return_types}'."
36+
f" Expected return type format is '{ret_types}'."
1837
)
38+
1939
return func(*args, **kwargs)
2040

2141
return inner
@@ -42,3 +62,20 @@ def inner(*args):
4262
return inner
4363

4464
return wrapper
65+
66+
67+
def parse_return_types(ret):
68+
# If func returns a actual python type
69+
if type(ret) is type:
70+
return [type]
71+
72+
# If func returns a non iterable such as True, 4, or 23.3
73+
if type(ret) in non_iterables:
74+
return [type(ret)]
75+
76+
# Strings are iterable, but single string should not be iterated
77+
if type(ret) == str:
78+
return [str]
79+
80+
# Return a list of types for the func returns
81+
return [type(r) for r in ret]

src/pystrong/enforcer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,22 @@ def __init__(self, *args, **kwargs):
2222
for attr_name, _type in kwargs.items():
2323
if type(_type) is not type:
2424
raise BadTypeError(
25-
"Key word arguments must be in (attr_name, type) format. Ensure 'type' is a valid python type."
25+
"Key word arguments must be in (attr_name, type) format. Ensure"
26+
" 'type' is a valid python type."
2627
)
2728
self.__dict__.update({TYPE_ATTR_FORMAT.format(attr_name): _type})
2829

2930
def __setattr__(self, key: str, value: Any) -> None:
3031
if not self.__dict__.get(TYPE_ATTR_FORMAT.format(key)):
3132
raise AttributeTypeNotSet(
32-
f"No type for attr '{key}' set. This is most likely due to not calling super in the derived class constructor."
33+
f"No type for attr '{key}' set. This is most likely due to not calling"
34+
" super in the derived class constructor."
3335
)
3436

3537
if type(value) is not getattr(self, f"___{key}_type"):
3638
raise EnforcedTypeError(
37-
f"Cant not assign type: {type(value)} to attribute '{key}'. Must be of type: {getattr(self, f'___{key}_type')}"
39+
f"Cant not assign type: {type(value)} to attribute '{key}'. Must be of"
40+
f" type: {getattr(self, f'___{key}_type')}"
3841
)
3942
super().__setattr__(key, value)
4043

tests/test_decorators.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,65 @@
11
import pytest
2-
from src.pystrong import check_return_type
2+
from src.pystrong import check_return_types
33
from src.pystrong.exceptions import EnforcedReturnTypeError
44

5-
6-
7-
@check_return_type(str)
5+
# Should pass
6+
@check_return_types(str)
87
def return_str():
98
return "hello"
109

11-
@check_return_type(int)
10+
11+
@check_return_types(int)
1212
def return_int():
1313
return 33
1414

15-
@check_return_type(int)
15+
16+
@check_return_types(type)
17+
def return_type():
18+
return bool
19+
20+
21+
@check_return_types(int, str, dict, list)
22+
def return_multiple():
23+
return 33, "hello", {}, []
24+
25+
26+
# Should fail
27+
@check_return_types(int)
1628
def return_wrong_type():
1729
return "hello"
1830

31+
32+
@check_return_types(int, str, dict, list)
33+
def return_wrong_multiple():
34+
return "hello", "hello", 33, 3.34
35+
36+
37+
@check_return_types(int, str, dict)
38+
def return_wrong_lengths():
39+
return 33
40+
41+
1942
with pytest.raises(EnforcedReturnTypeError):
20-
@check_return_type("error")
43+
44+
# Fail on check that all args are valid python types
45+
@check_return_types("error")
2146
def return_error():
22-
return 'error'
47+
return "error"
48+
2349
return_error()
2450

2551

26-
def test_check_return_type():
27-
assert(return_str())
28-
assert(return_int())
29-
52+
def test_check_return_types():
53+
assert return_str()
54+
assert return_int()
55+
assert return_type()
56+
assert return_multiple()
57+
3058
with pytest.raises(EnforcedReturnTypeError):
3159
return_wrong_type()
60+
61+
with pytest.raises(EnforcedReturnTypeError):
62+
return_wrong_lengths()
63+
64+
with pytest.raises(EnforcedReturnTypeError):
65+
return_wrong_multiple()

0 commit comments

Comments
 (0)