Skip to content

Commit 585f6ec

Browse files
committed
feat: allow registering custom boolean operators
1 parent f77d117 commit 585f6ec

File tree

9 files changed

+315
-91
lines changed

9 files changed

+315
-91
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,28 @@ expr = "starts_with(user.name, 'Sn')"
5555
print(evaluate(expr, context={"user": {"name": "Snoopy"}})) # True
5656
```
5757

58+
### Custom operators
59+
60+
```py
61+
from boolia import evaluate, DEFAULT_OPERATORS
62+
63+
custom_ops = DEFAULT_OPERATORS.copy()
64+
custom_ops.register(
65+
"XOR", # The operator identifier
66+
precedence=20, # Higher precedence than AND/OR
67+
evaluator=lambda left, right: bool(left) ^ bool(right), # XOR logic
68+
keywords=("xor",), # Use "xor" in expressions
69+
)
70+
71+
print(evaluate("true xor false", operators=custom_ops)) # True
72+
print(evaluate("true xor true", operators=custom_ops)) # False
73+
```
74+
75+
Operators can be declared with `keywords=("xor",)` for word-style syntax or `symbols=("^",)`
76+
for symbolic tokens. Use `compile_rule(expr, operators=custom_ops)` to persist custom
77+
operators inside compiled rules. When evaluating rules or rule groups you can still pass a
78+
different registry with `operators=` if you need to override their behavior.
79+
5880
### RuleBook
5981

6082
```py

boolia/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .resolver import default_resolver_factory, MissingPolicy
33
from .errors import MissingVariableError
44
from .functions import FunctionRegistry, DEFAULT_FUNCTIONS
5+
from .operators import OperatorRegistry, DEFAULT_OPERATORS
56

67
__all__ = [
78
"evaluate",
@@ -15,4 +16,6 @@
1516
"MissingVariableError",
1617
"FunctionRegistry",
1718
"DEFAULT_FUNCTIONS",
19+
"OperatorRegistry",
20+
"DEFAULT_OPERATORS",
1821
]

boolia/api.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from .ast import Node
77
from .resolver import default_resolver_factory, MissingPolicy
88
from .functions import FunctionRegistry, DEFAULT_FUNCTIONS
9+
from .operators import OperatorRegistry, DEFAULT_OPERATORS
910

1011

1112
RuleEntry = Union["Rule", "RuleGroup"]
1213
RuleMember = Union[str, RuleEntry]
1314

1415

15-
def compile_expr(source: str) -> Node:
16-
return parse(source)
16+
def compile_expr(source: str, *, operators: Optional[OperatorRegistry] = None) -> Node:
17+
return parse(source, operators)
1718

1819

1920
def evaluate(
@@ -25,26 +26,33 @@ def evaluate(
2526
on_missing: MissingPolicy = "false",
2627
default_value: Any = None,
2728
functions: Optional[FunctionRegistry] = None,
29+
operators: Optional[OperatorRegistry] = None,
2830
) -> bool:
29-
node = compile_expr(source_or_ast) if isinstance(source_or_ast, str) else source_or_ast
31+
ops = operators or DEFAULT_OPERATORS
32+
node = compile_expr(source_or_ast, operators=ops) if isinstance(source_or_ast, str) else source_or_ast
3033
ctx = context or {}
3134
tg = tags or set()
3235
res = resolver or default_resolver_factory(ctx, on_missing=on_missing, default_value=default_value)
3336
fns = functions or DEFAULT_FUNCTIONS
34-
out = node.eval(res, tg, fns)
37+
out = node.eval(res, tg, fns, ops)
3538
return bool(out)
3639

3740

3841
@dataclass
3942
class Rule:
4043
ast: Node
44+
operators: OperatorRegistry = DEFAULT_OPERATORS
4145

42-
def evaluate(self, **kwargs) -> bool:
43-
return evaluate(self.ast, **kwargs)
46+
def evaluate(self, *, operators: Optional[OperatorRegistry] = None, **kwargs) -> bool:
    """Evaluate this rule's AST and return the boolean result.

    Args:
        operators: Registry to use for this call only. When omitted (or
            falsy), the registry stored on the rule, ``self.operators``,
            is used instead.
        **kwargs: Forwarded unchanged to the module-level ``evaluate``.

    Returns:
        Whatever the module-level ``evaluate`` returns for ``self.ast``.
    """
    # ``operators`` is a named keyword-only parameter, so Python binds it
    # before ``**kwargs`` is collected: ``kwargs`` can never contain an
    # "operators" key. Copying the dict and popping that key was dead code.
    ops = operators or self.operators
    return evaluate(self.ast, operators=ops, **kwargs)
4451

4552

46-
def compile_rule(source: str) -> Rule:
47-
return Rule(compile_expr(source))
53+
def compile_rule(source: str, *, operators: Optional[OperatorRegistry] = None) -> Rule:
54+
ops = operators or DEFAULT_OPERATORS
55+
return Rule(compile_expr(source, operators=ops), ops)
4856

4957

5058
class RuleGroup:

boolia/ast.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@
66

77

88
class Node:
9-
def eval(self, resolve: Resolver, tags: Set[str], functions) -> Any:
9+
def eval(self, resolve: Resolver, tags: Set[str], functions, operators) -> Any:
1010
raise NotImplementedError()
1111

1212

1313
@dataclass
1414
class Literal(Node):
1515
value: Any
1616

17-
def eval(self, resolve, tags, functions):
17+
def eval(self, resolve, tags, functions, operators):
1818
return self.value
1919

2020

2121
@dataclass
2222
class Name(Node):
2323
parts: List[str] # e.g., ["house","light","on"] or ["car"]
2424

25-
def eval(self, resolve, tags, functions):
25+
def eval(self, resolve, tags, functions, operators):
2626
if len(self.parts) == 1:
2727
name = self.parts[0]
2828
val = resolve(self.parts)
@@ -39,8 +39,8 @@ class Unary(Node):
3939
op: str
4040
right: Node
4141

42-
def eval(self, resolve, tags, functions):
43-
v = self.right.eval(resolve, tags, functions)
42+
def eval(self, resolve, tags, functions, operators):
43+
v = self.right.eval(resolve, tags, functions, operators)
4444
if self.op == "NOT":
4545
return not bool(v)
4646
raise ValueError(f"Unknown unary op {self.op}")
@@ -52,39 +52,18 @@ class Binary(Node):
5252
op: str
5353
right: Node
5454

55-
def eval(self, resolve, tags, functions):
56-
if self.op == "AND":
57-
return bool(self.left.eval(resolve, tags, functions)) and bool(self.right.eval(resolve, tags, functions))
58-
if self.op == "OR":
59-
return bool(self.left.eval(resolve, tags, functions)) or bool(self.right.eval(resolve, tags, functions))
60-
le = self.left.eval(resolve, tags, functions)
61-
r = self.right.eval(resolve, tags, functions)
62-
if self.op == "EQ":
63-
return le == r
64-
if self.op == "NE":
65-
return le != r
66-
if self.op == "GT":
67-
return le > r
68-
if self.op == "LT":
69-
return le < r
70-
if self.op == "GE":
71-
return le >= r
72-
if self.op == "LE":
73-
return le <= r
74-
if self.op == "IN":
75-
try:
76-
return le in r
77-
except TypeError:
78-
return False
79-
raise ValueError(f"Unknown binary op {self.op}")
55+
def eval(self, resolve, tags, functions, operators):
    """Evaluate this binary expression.

    The left operand is always evaluated. For the built-in ``"AND"`` and
    ``"OR"`` operators the right operand is evaluated only when the left
    one does not already decide the result; every other operator receives
    both operand values through ``operators.evaluate``.

    Args:
        resolve: Resolver passed through to the operand nodes.
        tags: Tag set passed through to the operand nodes.
        functions: Function registry passed through to the operand nodes.
        operators: Registry whose ``evaluate(op, left, right)`` computes
            the result of the operator.

    Returns:
        ``False`` / ``True`` when ``AND`` / ``OR`` short-circuits, otherwise
        the value returned by ``operators.evaluate``.
    """
    left_val = self.left.eval(resolve, tags, functions, operators)

    # Keep AND/OR lazy: evaluating the right operand unconditionally makes
    # guard expressions (e.g. a null check followed by a comparison) raise
    # from the right-hand side even though the left side already decides
    # the outcome.
    # NOTE(review): a short-circuited AND/OR does not reach the registry's
    # evaluator for that operator - confirm this is acceptable for
    # registries that re-register "AND" or "OR".
    if self.op == "AND" and not left_val:
        return False
    if self.op == "OR" and left_val:
        return True

    right_val = self.right.eval(resolve, tags, functions, operators)
    return operators.evaluate(self.op, left_val, right_val)
8059

8160

8261
@dataclass
8362
class Call(Node):
8463
name: str
8564
args: List[Node]
8665

87-
def eval(self, resolve, tags, functions):
66+
def eval(self, resolve, tags, functions, operators):
8867
fn = functions.get(self.name)
89-
vals = [a.eval(resolve, tags, functions) for a in self.args]
68+
vals = [a.eval(resolve, tags, functions, operators) for a in self.args]
9069
return fn(*vals)

boolia/lexer.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,34 @@
1-
from typing import List, Tuple, Any
1+
from typing import List, Tuple, Any, Optional
2+
3+
from .operators import OperatorRegistry, DEFAULT_OPERATORS
24

35
Token = Tuple[str, Any] # (type, value)
46

5-
KEYWORDS = {"and", "or", "not", "in", "true", "false", "null", "none"}
6-
SYMS = {
7+
_FIXED_KEYWORDS = {
8+
"true": ("BOOL", True),
9+
"false": ("BOOL", False),
10+
"null": ("NULL", None),
11+
"none": ("NULL", None),
12+
"not": ("NOT", "not"),
13+
}
14+
15+
_BASE_SYMBOLS = {
716
"(": "LPAREN",
817
")": "RPAREN",
918
".": "DOT",
1019
",": "COMMA",
11-
"==": "EQ",
12-
"!=": "NE",
13-
">=": "GE",
14-
"<=": "LE",
15-
">": "GT",
16-
"<": "LT",
1720
}
1821

1922

20-
def tokenize(s: str) -> List[Token]:
23+
def tokenize(s: str, operators: Optional[OperatorRegistry] = None) -> List[Token]:
2124
import re
2225

26+
ops = operators or DEFAULT_OPERATORS
27+
keyword_tokens = ops.keyword_tokens()
28+
symbol_tokens = _BASE_SYMBOLS.copy()
29+
symbol_tokens.update(ops.symbol_tokens())
30+
max_symbol_len = max((len(sym) for sym in symbol_tokens), default=0)
31+
2332
tokens: List[Token] = []
2433
i = 0
2534
n = len(s)
@@ -35,14 +44,18 @@ def tokenize(s: str) -> List[Token]:
3544
if i >= n:
3645
break
3746

38-
if i + 1 < n and s[i : i + 2] in SYMS:
39-
tokens.append((SYMS[s[i : i + 2]], s[i : i + 2]))
40-
i += 2
41-
continue
42-
43-
if s[i] in SYMS:
44-
tokens.append((SYMS[s[i]], s[i]))
45-
i += 1
47+
matched_symbol = False
48+
for length in range(max_symbol_len, 0, -1):
49+
if i + length > n:
50+
continue
51+
frag = s[i : i + length]
52+
token_type = symbol_tokens.get(frag)
53+
if token_type is not None:
54+
tokens.append((token_type, frag))
55+
i += length
56+
matched_symbol = True
57+
break
58+
if matched_symbol:
4659
continue
4760

4861
m = string.match(s, i)
@@ -65,17 +78,18 @@ def tokenize(s: str) -> List[Token]:
6578
if m:
6679
name = m.group(0)
6780
low = name.lower()
68-
if low in KEYWORDS:
69-
if low == "true":
70-
tokens.append(("BOOL", True))
71-
elif low == "false":
72-
tokens.append(("BOOL", False))
73-
elif low in ("null", "none"):
74-
tokens.append(("NULL", None))
75-
else:
76-
tokens.append((low.upper(), low)) # AND, OR, NOT, IN
77-
else:
78-
tokens.append(("IDENT", name))
81+
fixed = _FIXED_KEYWORDS.get(low)
82+
if fixed is not None:
83+
tokens.append(fixed)
84+
i = m.end()
85+
continue
86+
op_token = keyword_tokens.get(low)
87+
if op_token is not None:
88+
tokens.append((op_token, low))
89+
i = m.end()
90+
continue
91+
92+
tokens.append(("IDENT", name))
7993
i = m.end()
8094
continue
8195

0 commit comments

Comments
 (0)