Skip to content

Commit 284b22f

Browse files
committed
feat(param_logic): use a safer eval code, with limited arithmetic capabilities.
1 parent 14b5ae9 commit 284b22f

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""
2+
This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator.
3+
4+
SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas <[email protected]>
5+
6+
SPDX-License-Identifier: GPL-3.0-or-later
7+
"""
8+
9+
import ast
10+
import logging
11+
import math
12+
import operator
13+
from typing import Callable, Union, cast
14+
15+
logger = logging.getLogger(__name__)
16+
17+
# Type aliases
18+
Number = Union[int, float]
19+
MathFunc = Callable[..., Number]
20+
BinOperator = Callable[[Number, Number], Number]
21+
UnOperator = Callable[[Number], Number]
22+
23+
def safe_eval(s: str) -> Number:
24+
def checkmath(x: str, *args: Number) -> Number:
25+
if x not in [x for x in dir(math) if "__" not in x]:
26+
msg = f"Unknown func {x}()"
27+
raise SyntaxError(msg)
28+
fun = cast(MathFunc, getattr(math, x))
29+
try:
30+
return fun(*args)
31+
except TypeError as e:
32+
msg = f"Invalid arguments for {x}(): {e!s}"
33+
raise SyntaxError(msg) from e
34+
35+
bin_ops: dict[type[ast.operator], BinOperator] = {
36+
ast.Add: operator.add,
37+
ast.Sub: operator.sub,
38+
ast.Mult: operator.mul,
39+
ast.Div: operator.truediv,
40+
ast.Mod: operator.mod,
41+
ast.Pow: operator.pow,
42+
ast.Call: checkmath,
43+
ast.BinOp: ast.BinOp,
44+
}
45+
46+
un_ops: dict[type[ast.UnaryOp], UnOperator] = {
47+
ast.USub: operator.neg,
48+
ast.UAdd: operator.pos,
49+
ast.UnaryOp: ast.UnaryOp,
50+
}
51+
52+
tree = ast.parse(s, mode="eval")
53+
54+
def _eval(node: ast.AST) -> Number:
55+
if isinstance(node, ast.Expression):
56+
logger.debug("Expr")
57+
return _eval(node.body)
58+
if isinstance(node, ast.Constant):
59+
logger.info("Const")
60+
return cast(Number, node.value)
61+
if isinstance(node, ast.Name):
62+
# Handle math constants like pi, e, etc.
63+
logger.info("MathConst")
64+
if hasattr(math, node.id):
65+
return cast(Number, getattr(math, node.id))
66+
msg = f"Unknown constant: {node.id}"
67+
raise SyntaxError(msg)
68+
if isinstance(node, ast.BinOp):
69+
logger.debug("BinOp")
70+
left = _eval(node.left)
71+
right = _eval(node.right)
72+
if type(node.op) not in bin_ops:
73+
msg = f"Unsupported operator: {type(node.op)}"
74+
raise SyntaxError(msg)
75+
return bin_ops[type(node.op)](left, right)
76+
if isinstance(node, ast.UnaryOp):
77+
logger.debug("UpOp")
78+
operand = _eval(node.operand)
79+
if type(node.op) not in un_ops:
80+
msg = f"Unsupported operator: {type(node.op)}"
81+
raise SyntaxError(msg)
82+
return un_ops[type(node.op)](operand)
83+
if isinstance(node, ast.Call):
84+
if not isinstance(node.func, ast.Name):
85+
msg = "Only direct math function calls allowed"
86+
raise SyntaxError(msg)
87+
args = [_eval(x) for x in node.args]
88+
return checkmath(node.func.id, *args)
89+
msg = f"Bad syntax, {type(node)}"
90+
raise SyntaxError(msg)
91+
92+
return _eval(tree)

tests/test_safe_eval.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Tests for safe_eval.py.
4+
5+
This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator
6+
7+
SPDX-FileCopyrightText: 2024 Amilcar do Carmo Lucas <[email protected]>
8+
9+
SPDX-License-Identifier: GPL-3.0-or-later
10+
"""
11+
12+
import math
13+
14+
import pytest
15+
16+
from ardupilot_methodic_configurator.safe_eval import safe_eval
17+
18+
19+
def test_basic_arithmetic() -> None:
20+
"""Test basic arithmetic operations."""
21+
assert safe_eval("1+1") == 2
22+
assert safe_eval("1+-5") == -4
23+
assert safe_eval("-1") == -1
24+
assert safe_eval("-+1") == -1
25+
assert safe_eval("(100*10)+6") == 1006
26+
assert safe_eval("100*(10+6)") == 1600
27+
assert safe_eval("2**4") == 16
28+
assert pytest.approx(safe_eval("1.2345 * 10")) == 12.345
29+
30+
31+
def test_math_functions() -> None:
32+
"""Test mathematical functions."""
33+
assert safe_eval("sqrt(16)+1") == 5
34+
assert safe_eval("sin(0)") == 0
35+
assert safe_eval("cos(0)") == 1
36+
assert safe_eval("tan(0)") == 0
37+
assert safe_eval("log(1)") == 0
38+
assert safe_eval("exp(0)") == 1
39+
assert safe_eval("pi") == math.pi
40+
41+
42+
def test_complex_expressions() -> None:
43+
"""Test more complex mathematical expressions."""
44+
assert safe_eval("2 * (3 + 4)") == 14
45+
assert safe_eval("2 ** 3 * 4") == 32
46+
assert safe_eval("sqrt(16) + sqrt(9)") == 7
47+
assert safe_eval("sin(pi/2)") == 1
48+
49+
50+
def test_error_cases() -> None:
51+
"""Test error conditions."""
52+
with pytest.raises(SyntaxError):
53+
safe_eval("1 + ") # Incomplete expression
54+
55+
with pytest.raises(SyntaxError):
56+
safe_eval("unknown_func(10)") # Unknown function
57+
58+
with pytest.raises(SyntaxError):
59+
safe_eval("1 = 1") # Invalid operator
60+
61+
with pytest.raises(SyntaxError):
62+
safe_eval("import os") # Attempted import
63+
64+
65+
def test_nested_expressions() -> None:
66+
"""Test nested mathematical expressions."""
67+
assert safe_eval("sqrt(pow(3,2) + pow(4,2))") == 5 # Pythagorean theorem
68+
assert safe_eval("log(exp(1))") == 1
69+
assert safe_eval("sin(pi/6)**2 + cos(pi/6)**2") == pytest.approx(1)
70+
71+
72+
def test_division_by_zero() -> None:
73+
"""Test division by zero handling."""
74+
with pytest.raises(ZeroDivisionError):
75+
safe_eval("1/0")
76+
with pytest.raises(ZeroDivisionError):
77+
safe_eval("10 % 0")
78+
79+
80+
def test_invalid_math_functions() -> None:
81+
"""Test invalid math function calls."""
82+
with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"):
83+
safe_eval("sin()") # Missing argument
84+
with pytest.raises(SyntaxError, match=r".*takes exactly one argument.*"):
85+
safe_eval("sin(1,2)") # Too many arguments
86+
with pytest.raises(ValueError, match=r"math domain error"):
87+
safe_eval("sqrt(-1)") # Domain error
88+
with pytest.raises(ValueError, match=r"math domain error"):
89+
safe_eval("log(-1)") # Range error
90+
with pytest.raises(SyntaxError, match=r"Unknown func.*"):
91+
safe_eval("unknown(1)") # Unknown function
92+
93+
94+
def test_security() -> None:
95+
"""Test against code injection attempts."""
96+
with pytest.raises(SyntaxError):
97+
safe_eval("__import__('os').system('ls')")
98+
with pytest.raises(SyntaxError):
99+
safe_eval("open('/etc/passwd')")
100+
with pytest.raises(SyntaxError):
101+
safe_eval("eval('1+1')")
102+
103+
104+
def test_operator_precedence() -> None:
105+
"""Test operator precedence rules."""
106+
assert safe_eval("2 + 3 * 4") == 14
107+
assert safe_eval("(2 + 3) * 4") == 20
108+
assert safe_eval("-2 ** 2") == -4 # Exponentiation before negation
109+
assert safe_eval("-(2 ** 2)") == -4
110+
111+
112+
def test_float_precision() -> None:
113+
"""Test floating point precision handling."""
114+
assert pytest.approx(safe_eval("0.1 + 0.2")) == 0.3
115+
assert pytest.approx(safe_eval("sin(pi/2)")) == 1.0
116+
assert pytest.approx(safe_eval("cos(pi)")) == -1.0
117+
assert pytest.approx(safe_eval("exp(log(2.718281828))")) == math.e
118+
119+
120+
def test_math_constants() -> None:
121+
"""Test mathematical constants."""
122+
assert safe_eval("pi") == math.pi
123+
assert safe_eval("e") == math.e
124+
assert safe_eval("tau") == math.tau
125+
assert safe_eval("inf") == math.inf
126+
with pytest.raises(SyntaxError):
127+
safe_eval("not_a_constant")

0 commit comments

Comments
 (0)