diff --git a/pina/type_checker.py b/pina/type_checker.py new file mode 100644 index 000000000..e8c908ac9 --- /dev/null +++ b/pina/type_checker.py @@ -0,0 +1,93 @@ +"""Module for enforcing type hints in Python functions.""" + +import inspect +import typing +import logging + + +def enforce_types(func): + """ + Function decorator to enforce type hints at runtime. + + This decorator checks the types of the arguments and of the return value of + the decorated function against the type hints specified in the function + signature. If the types do not match, a TypeError is raised. + Type checking is only performed when the logging level is set to `DEBUG`. + + :param Callable func: The function to be decorated. + :return: The decorated function with enforced type hints. + :rtype: Callable + + :Example: + + >>> @enforce_types + def dummy_function(a: int, b: float) -> float: + ... return a+b + + # This always works. + dummy_function(1, 2.0) + + # This raises a TypeError for the second argument, if logging is set to + # `DEBUG`. + dummy_function(1, "Hello, world!") + + + >>> @enforce_types + def dummy_function2(a: int, right: bool) -> float: + ... if right: + ... return float(a) + ... else: + ... return "Hello, world!" + + # This always works. + dummy_function2(1, right=True) + + # This raises a TypeError for the return value if logging is set to + # `DEBUG`. + dummy_function2(1, right=False) + """ + + def wrapper(*args, **kwargs): + """ + Wrapper function to enforce type hints. + + :param tuple args: Positional arguments passed to the function. + :param dict kwargs: Keyword arguments passed to the function. + :raises TypeError: If the argument or return type does not match the + specified type hints. + :return: The result of the decorated function. + :rtype: Any + """ + level = logging.getLevelName(logging.getLogger().getEffectiveLevel()) + + # Enforce type hints only in debug mode + if level != "DEBUG": + return func(*args, **kwargs) + + # Get the type hints for the function arguments + hints = typing.get_type_hints(func) + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + for arg_name, arg_value in bound.arguments.items(): + expected_type = hints.get(arg_name) + if expected_type and not isinstance(arg_value, expected_type): + raise TypeError( + f"Argument '{arg_name}' must be {expected_type.__name__}, " + f"but got {type(arg_value).__name__}!" + ) + + # Get the type hints for the return values + return_type = hints.get("return") + result = func(*args, **kwargs) + + if return_type and not isinstance(result, return_type): + raise TypeError( + f"Return value must be {return_type.__name__}, " + f"but got {type(result).__name__}!" + ) + + return result + + return wrapper diff --git a/tests/test_type_checker.py b/tests/test_type_checker.py new file mode 100644 index 000000000..554d9613b --- /dev/null +++ b/tests/test_type_checker.py @@ -0,0 +1,55 @@ +import pytest +import logging +import math +from pina.type_checker import enforce_types + + +# Definition of a test function for arguments +@enforce_types +def foo_function1(a: int, b: float) -> float: + return a + b + + +# Definition of a test function for return values +@enforce_types +def foo_function2(a: int, right: bool) -> float: + if right: + return float(a) + else: + return "Hello, world!" + + +def test_argument_type_checking(): + + # Setting logging level to INFO, which should not trigger type checking + logging.getLogger().setLevel(logging.INFO) + + # Both should work, even if the arguments are not of the expected type + assert math.isclose(foo_function1(a=1, b=2.0), 3.0) + assert math.isclose(foo_function1(a=1, b=2), 3.0) + + # Setting logging level to DEBUG, which should trigger type checking + logging.getLogger().setLevel(logging.DEBUG) + + # The second should fail, as the second argument is an int + assert math.isclose(foo_function1(a=1, b=2.0), 3.0) + with pytest.raises(TypeError): + foo_function1(a=1, b=2) + + +def test_return_type_checking(): + + # Setting logging level to INFO, which should not trigger type checking + logging.getLogger().setLevel(logging.INFO) + + # Both should work, even if the return value is not of the expected type + assert math.isclose(foo_function2(a=1, right=True), 1.0) + assert foo_function2(a=1, right=False) == "Hello, world!" + + # Setting logging level to DEBUG, which should trigger type checking + logging.getLogger().setLevel(logging.DEBUG) + + # The second should fail, as the return value is a string + assert math.isclose(foo_function2(a=1, right=True), 1.0) + with pytest.raises(TypeError): + foo_function2(a=1, right=False)