Skip to content

Commit 158cc42

Browse files
committed
Move binop handling logic to expr_pass, remove delayed imports of get_operand_value
1 parent 2a1eabc commit 158cc42

File tree

4 files changed

+117
-119
lines changed

4 files changed

+117
-119
lines changed

pythonbpf/binary_ops.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

pythonbpf/expr/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .expr_pass import eval_expr, handle_expr
1+
from .expr_pass import eval_expr, handle_expr, get_operand_value
22
from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth
33

44
__all__ = [
@@ -7,4 +7,5 @@
77
"convert_to_bool",
88
"get_base_type_and_depth",
99
"deref_to_depth",
10+
"get_operand_value",
1011
]

pythonbpf/expr/expr_pass.py

Lines changed: 109 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,118 @@
55
from typing import Dict
66

77
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
8-
from .type_normalization import convert_to_bool, handle_comparator
8+
from .type_normalization import (
9+
convert_to_bool,
10+
handle_comparator,
11+
get_base_type_and_depth,
12+
deref_to_depth,
13+
)
914

1015
logger: Logger = logging.getLogger(__name__)
1116

1217

18+
def get_operand_value(
19+
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
20+
):
21+
"""Extract the value from an operand, handling variables and constants."""
22+
logger.info(f"Getting operand value for: {ast.dump(operand)}")
23+
if isinstance(operand, ast.Name):
24+
if operand.id in local_sym_tab:
25+
var = local_sym_tab[operand.id].var
26+
var_type = var.type
27+
base_type, depth = get_base_type_and_depth(var_type)
28+
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
29+
val = deref_to_depth(func, builder, var, depth)
30+
return val
31+
raise ValueError(f"Undefined variable: {operand.id}")
32+
elif isinstance(operand, ast.Constant):
33+
if isinstance(operand.value, int):
34+
cst = ir.Constant(ir.IntType(64), int(operand.value))
35+
return cst
36+
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
37+
elif isinstance(operand, ast.BinOp):
38+
res = _handle_binary_op_impl(
39+
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
40+
)
41+
return res
42+
else:
43+
res = eval_expr(
44+
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
45+
)
46+
if res is None:
47+
raise ValueError(f"Failed to evaluate call expression: {operand}")
48+
val, _ = res
49+
logger.info(f"Evaluated expr to {val} of type {val.type}")
50+
base_type, depth = get_base_type_and_depth(val.type)
51+
if depth > 0:
52+
val = deref_to_depth(func, builder, val, depth)
53+
return val
54+
raise TypeError(f"Unsupported operand type: {type(operand)}")
55+
56+
57+
def _handle_binary_op_impl(
58+
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
59+
):
60+
op = rval.op
61+
left = get_operand_value(
62+
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
63+
)
64+
right = get_operand_value(
65+
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
66+
)
67+
logger.info(f"left is {left}, right is {right}, op is {op}")
68+
69+
# NOTE: Before doing the operation, if the operands are integers
70+
# we always extend them to i64. The assignment to LHS will take
71+
# care of truncation if needed.
72+
if isinstance(left.type, ir.IntType) and left.type.width < 64:
73+
left = builder.sext(left, ir.IntType(64))
74+
if isinstance(right.type, ir.IntType) and right.type.width < 64:
75+
right = builder.sext(right, ir.IntType(64))
76+
77+
# Map AST operation nodes to LLVM IR builder methods
78+
op_map = {
79+
ast.Add: builder.add,
80+
ast.Sub: builder.sub,
81+
ast.Mult: builder.mul,
82+
ast.Div: builder.sdiv,
83+
ast.Mod: builder.srem,
84+
ast.LShift: builder.shl,
85+
ast.RShift: builder.lshr,
86+
ast.BitOr: builder.or_,
87+
ast.BitXor: builder.xor,
88+
ast.BitAnd: builder.and_,
89+
ast.FloorDiv: builder.udiv,
90+
}
91+
92+
if type(op) in op_map:
93+
result = op_map[type(op)](left, right)
94+
return result
95+
else:
96+
raise SyntaxError("Unsupported binary operation")
97+
98+
99+
def _handle_binary_op(
100+
func,
101+
module,
102+
rval,
103+
builder,
104+
var_name,
105+
local_sym_tab,
106+
map_sym_tab,
107+
structs_sym_tab=None,
108+
):
109+
result = _handle_binary_op_impl(
110+
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
111+
)
112+
if var_name and var_name in local_sym_tab:
113+
logger.info(
114+
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
115+
)
116+
builder.store(result, local_sym_tab[var_name].var)
117+
return result, result.type
118+
119+
13120
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
14121
"""Handle ast.Name expressions."""
15122
if expr.id in local_sym_tab:
@@ -194,8 +301,6 @@ def _handle_unary_op(
194301
logger.error("Only 'not' and '-' unary operators are supported")
195302
return None
196303

197-
from pythonbpf.binary_ops import get_operand_value
198-
199304
operand = get_operand_value(
200305
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
201306
)
@@ -421,9 +526,7 @@ def eval_expr(
421526
elif isinstance(expr, ast.Attribute):
422527
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
423528
elif isinstance(expr, ast.BinOp):
424-
from pythonbpf.binary_ops import handle_binary_op
425-
426-
return handle_binary_op(
529+
return _handle_binary_op(
427530
func,
428531
module,
429532
expr,

pythonbpf/helper/helper_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
from collections.abc import Callable
44

55
from llvmlite import ir
6-
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
7-
from pythonbpf.binary_ops import get_operand_value
6+
from pythonbpf.expr import (
7+
eval_expr,
8+
get_base_type_and_depth,
9+
deref_to_depth,
10+
get_operand_value,
11+
)
812

913
logger = logging.getLogger(__name__)
1014

0 commit comments

Comments
 (0)