diff --git a/ardupilot_methodic_configurator/safe_eval.py b/ardupilot_methodic_configurator/safe_eval.py new file mode 100644 index 00000000..181d300f --- /dev/null +++ b/ardupilot_methodic_configurator/safe_eval.py @@ -0,0 +1,92 @@ +""" +This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator. + +SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas + +SPDX-License-Identifier: GPL-3.0-or-later +""" + +import ast +import logging +import math +import operator +from typing import Callable, Union, cast + +logger = logging.getLogger(__name__) + +# Type aliases +Number = Union[int, float] +MathFunc = Callable[..., Number] +BinOperator = Callable[[Number, Number], Number] +UnOperator = Callable[[Number], Number] + +def safe_eval(s: str) -> Number: + def checkmath(x: str, *args: Number) -> Number: + if x not in [x for x in dir(math) if "__" not in x]: + msg = f"Unknown func {x}()" + raise SyntaxError(msg) + fun = cast(MathFunc, getattr(math, x)) + try: + return fun(*args) + except TypeError as e: + msg = f"Invalid arguments for {x}(): {e!s}" + raise SyntaxError(msg) from e + + bin_ops: dict[type[ast.operator], BinOperator] = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.Call: checkmath, + ast.BinOp: ast.BinOp, + } + + un_ops: dict[type[ast.UnaryOp], UnOperator] = { + ast.USub: operator.neg, + ast.UAdd: operator.pos, + ast.UnaryOp: ast.UnaryOp, + } + + tree = ast.parse(s, mode="eval") + + def _eval(node: ast.AST) -> Number: + if isinstance(node, ast.Expression): + logger.debug("Expr") + return _eval(node.body) + if isinstance(node, ast.Constant): + logger.info("Const") + return cast(Number, node.value) + if isinstance(node, ast.Name): + # Handle math constants like pi, e, etc. + logger.info("MathConst") + if hasattr(math, node.id): + return cast(Number, getattr(math, node.id)) + msg = f"Unknown constant: {node.id}" + raise SyntaxError(msg) + if isinstance(node, ast.BinOp): + logger.debug("BinOp") + left = _eval(node.left) + right = _eval(node.right) + if type(node.op) not in bin_ops: + msg = f"Unsupported operator: {type(node.op)}" + raise SyntaxError(msg) + return bin_ops[type(node.op)](left, right) + if isinstance(node, ast.UnaryOp): + logger.debug("UpOp") + operand = _eval(node.operand) + if type(node.op) not in un_ops: + msg = f"Unsupported operator: {type(node.op)}" + raise SyntaxError(msg) + return un_ops[type(node.op)](operand) + if isinstance(node, ast.Call): + if not isinstance(node.func, ast.Name): + msg = "Only direct math function calls allowed" + raise SyntaxError(msg) + args = [_eval(x) for x in node.args] + return checkmath(node.func.id, *args) + msg = f"Bad syntax, {type(node)}" + raise SyntaxError(msg) + + return _eval(tree) diff --git a/tests/test_safe_eval.py b/tests/test_safe_eval.py new file mode 100755 index 00000000..16359854 --- /dev/null +++ b/tests/test_safe_eval.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +Tests for safe_eval.py. + +This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator + +SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas + +SPDX-License-Identifier: GPL-3.0-or-later +""" + +import math + +import pytest + +from ardupilot_methodic_configurator.safe_eval import safe_eval + + +def test_basic_arithmetic() -> None: + """Test basic arithmetic operations.""" + assert safe_eval("1+1") == 2 + assert safe_eval("1+-5") == -4 + assert safe_eval("-1") == -1 + assert safe_eval("-+1") == -1 + assert safe_eval("(100*10)+6") == 1006 + assert safe_eval("100*(10+6)") == 1600 + assert safe_eval("2**4") == 16 + assert pytest.approx(safe_eval("1.2345 * 10")) == 12.345 + + +def test_math_functions() -> None: + """Test mathematical functions.""" + assert safe_eval("sqrt(16)+1") == 5 + assert safe_eval("sin(0)") == 0 + assert safe_eval("cos(0)") == 1 + assert safe_eval("tan(0)") == 0 + assert safe_eval("log(1)") == 0 + assert safe_eval("exp(0)") == 1 + assert safe_eval("pi") == math.pi + + +def test_complex_expressions() -> None: + """Test more complex mathematical expressions.""" + assert safe_eval("2 * (3 + 4)") == 14 + assert safe_eval("2 ** 3 * 4") == 32 + assert safe_eval("sqrt(16) + sqrt(9)") == 7 + assert safe_eval("sin(pi/2)") == 1 + + +def test_error_cases() -> None: + """Test error conditions.""" + with pytest.raises(SyntaxError): + safe_eval("1 + ") # Incomplete expression + + with pytest.raises(SyntaxError): + safe_eval("unknown_func(10)") # Unknown function + + with pytest.raises(SyntaxError): + safe_eval("1 = 1") # Invalid operator + + with pytest.raises(SyntaxError): + safe_eval("import os") # Attempted import + + +def test_nested_expressions() -> None: + """Test nested mathematical expressions.""" + assert safe_eval("sqrt(pow(3,2) + pow(4,2))") == 5 # Pythagorean theorem + assert safe_eval("log(exp(1))") == 1 + assert safe_eval("sin(pi/6)**2 + cos(pi/6)**2") == pytest.approx(1) + + +def test_division_by_zero() -> None: + """Test division by zero handling.""" + with pytest.raises(ZeroDivisionError): + safe_eval("1/0") + with pytest.raises(ZeroDivisionError): + safe_eval("10 % 0") + + +def test_invalid_math_functions() -> None: + """Test invalid math function calls.""" + with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"): + safe_eval("sin()") # Missing argument + with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"): + safe_eval("sin(1,2)") # Too many arguments + with pytest.raises(ValueError, match=r"math domain error"): + safe_eval("sqrt(-1)") # Domain error + with pytest.raises(ValueError, match=r"math domain error"): + safe_eval("log(-1)") # Range error + with pytest.raises(SyntaxError, match=r"Unknown func.*"): + safe_eval("unknown(1)") # Unknown function + + +def test_security() -> None: + """Test against code injection attempts.""" + with pytest.raises(SyntaxError): + safe_eval("__import__('os').system('ls')") + with pytest.raises(SyntaxError): + safe_eval("open('/etc/passwd')") + with pytest.raises(SyntaxError): + safe_eval("eval('1+1')") + + +def test_operator_precedence() -> None: + """Test operator precedence rules.""" + assert safe_eval("2 + 3 * 4") == 14 + assert safe_eval("(2 + 3) * 4") == 20 + assert safe_eval("-2 ** 2") == -4 # Exponentiation before negation + assert safe_eval("-(2 ** 2)") == -4 + + +def test_float_precision() -> None: + """Test floating point precision handling.""" + assert pytest.approx(safe_eval("0.1 + 0.2")) == 0.3 + assert pytest.approx(safe_eval("sin(pi/2)")) == 1.0 + assert pytest.approx(safe_eval("cos(pi)")) == -1.0 + assert pytest.approx(safe_eval("exp(log(2.718281828))")) == math.e + + +def test_math_constants() -> None: + """Test mathematical constants.""" + assert safe_eval("pi") == math.pi + assert safe_eval("e") == math.e + assert safe_eval("tau") == math.tau + assert safe_eval("inf") == math.inf + with pytest.raises(SyntaxError): + safe_eval("not_a_constant")