Skip to content

Commit 7a67041

Browse files
committed
Move CallHandlerRegistry to expr/call_registry.py, annotate eval_expr
1 parent 45e6ce5 commit 7a67041

File tree

3 files changed

+130
-106
lines changed

3 files changed

+130
-106
lines changed

pythonbpf/expr/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from .expr_pass import eval_expr, handle_expr, get_operand_value, CallHandlerRegistry
1+
from .expr_pass import eval_expr, handle_expr, get_operand_value
22
from .type_normalization import convert_to_bool, get_base_type_and_depth
33
from .ir_ops import deref_to_depth
4+
from .call_registry import CallHandlerRegistry
45

56
__all__ = [
67
"eval_expr",

pythonbpf/expr/call_registry.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
class CallHandlerRegistry:
2+
"""Registry for handling different types of calls (helpers, etc.)"""
3+
4+
_handler = None
5+
6+
@classmethod
7+
def set_handler(cls, handler):
8+
"""Set the handler for unknown calls"""
9+
cls._handler = handler
10+
11+
@classmethod
12+
def handle_call(
13+
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
14+
):
15+
"""Handle a call using the registered handler"""
16+
if cls._handler is None:
17+
return None
18+
return cls._handler(
19+
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
20+
)

pythonbpf/expr/expr_pass.py

Lines changed: 108 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict
66

77
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
8+
from .call_registry import CallHandlerRegistry
89
from .type_normalization import (
910
convert_to_bool,
1011
handle_comparator,
@@ -14,27 +15,106 @@
1415

1516
logger: Logger = logging.getLogger(__name__)
1617

18+
# ============================================================================
19+
# Leaf Handlers (No Recursive eval_expr calls)
20+
# ============================================================================
1721

18-
class CallHandlerRegistry:
19-
"""Registry for handling different types of calls (helpers, etc.)"""
2022

21-
_handler = None
23+
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
24+
"""Handle ast.Name expressions."""
25+
if expr.id in local_sym_tab:
26+
var = local_sym_tab[expr.id].var
27+
val = builder.load(var)
28+
return val, local_sym_tab[expr.id].ir_type
29+
else:
30+
logger.info(f"Undefined variable {expr.id}")
31+
return None
32+
33+
34+
def _handle_constant_expr(module, builder, expr: ast.Constant):
35+
"""Handle ast.Constant expressions."""
36+
if isinstance(expr.value, int) or isinstance(expr.value, bool):
37+
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
38+
elif isinstance(expr.value, str):
39+
str_name = f".str.{id(expr)}"
40+
str_bytes = expr.value.encode("utf-8") + b"\x00"
41+
str_type = ir.ArrayType(ir.IntType(8), len(str_bytes))
42+
str_constant = ir.Constant(str_type, bytearray(str_bytes))
43+
44+
# Create global variable
45+
global_str = ir.GlobalVariable(module, str_type, name=str_name)
46+
global_str.linkage = "internal"
47+
global_str.global_constant = True
48+
global_str.initializer = str_constant
49+
50+
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
51+
return str_ptr, ir.PointerType(ir.IntType(8))
52+
else:
53+
logger.error(f"Unsupported constant type {ast.dump(expr)}")
54+
return None
55+
56+
57+
def _handle_attribute_expr(
58+
expr: ast.Attribute,
59+
local_sym_tab: Dict,
60+
structs_sym_tab: Dict,
61+
builder: ir.IRBuilder,
62+
):
63+
"""Handle ast.Attribute expressions for struct field access."""
64+
if isinstance(expr.value, ast.Name):
65+
var_name = expr.value.id
66+
attr_name = expr.attr
67+
if var_name in local_sym_tab:
68+
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
69+
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
70+
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
71+
metadata = structs_sym_tab[var_metadata]
72+
if attr_name in metadata.fields:
73+
gep = metadata.gep(builder, var_ptr, attr_name)
74+
val = builder.load(gep)
75+
field_type = metadata.field_type(attr_name)
76+
return val, field_type
77+
return None
2278

23-
@classmethod
24-
def set_handler(cls, handler):
25-
"""Set the handler for unknown calls"""
26-
cls._handler = handler
2779

28-
@classmethod
29-
def handle_call(
30-
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
80+
def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder):
81+
"""Handle deref function calls."""
82+
logger.info(f"Handling deref {ast.dump(expr)}")
83+
if len(expr.args) != 1:
84+
logger.info("deref takes exactly one argument")
85+
return None
86+
87+
arg = expr.args[0]
88+
if (
89+
isinstance(arg, ast.Call)
90+
and isinstance(arg.func, ast.Name)
91+
and arg.func.id == "deref"
3192
):
32-
"""Handle a call using the registered handler"""
33-
if cls._handler is None:
93+
logger.info("Multiple deref not supported")
94+
return None
95+
96+
if isinstance(arg, ast.Name):
97+
if arg.id in local_sym_tab:
98+
arg_ptr = local_sym_tab[arg.id].var
99+
else:
100+
logger.info(f"Undefined variable {arg.id}")
34101
return None
35-
return cls._handler(
36-
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
37-
)
102+
else:
103+
logger.info("Unsupported argument type for deref")
104+
return None
105+
106+
if arg_ptr is None:
107+
logger.info("Failed to evaluate deref argument")
108+
return None
109+
110+
# Load the value from pointer
111+
val = builder.load(arg_ptr)
112+
return val, local_sym_tab[arg.id].ir_type
113+
114+
115+
# ============================================================================
116+
# Binary Operations
117+
# ============================================================================
38118

39119

40120
def get_operand_value(
@@ -139,96 +219,9 @@ def _handle_binary_op(
139219
return result, result.type
140220

141221

142-
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
143-
"""Handle ast.Name expressions."""
144-
if expr.id in local_sym_tab:
145-
var = local_sym_tab[expr.id].var
146-
val = builder.load(var)
147-
return val, local_sym_tab[expr.id].ir_type
148-
else:
149-
logger.info(f"Undefined variable {expr.id}")
150-
return None
151-
152-
153-
def _handle_constant_expr(module, builder, expr: ast.Constant):
154-
"""Handle ast.Constant expressions."""
155-
if isinstance(expr.value, int) or isinstance(expr.value, bool):
156-
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
157-
elif isinstance(expr.value, str):
158-
str_name = f".str.{id(expr)}"
159-
str_bytes = expr.value.encode("utf-8") + b"\x00"
160-
str_type = ir.ArrayType(ir.IntType(8), len(str_bytes))
161-
str_constant = ir.Constant(str_type, bytearray(str_bytes))
162-
163-
# Create global variable
164-
global_str = ir.GlobalVariable(module, str_type, name=str_name)
165-
global_str.linkage = "internal"
166-
global_str.global_constant = True
167-
global_str.initializer = str_constant
168-
169-
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
170-
return str_ptr, ir.PointerType(ir.IntType(8))
171-
else:
172-
logger.error(f"Unsupported constant type {ast.dump(expr)}")
173-
return None
174-
175-
176-
def _handle_attribute_expr(
177-
expr: ast.Attribute,
178-
local_sym_tab: Dict,
179-
structs_sym_tab: Dict,
180-
builder: ir.IRBuilder,
181-
):
182-
"""Handle ast.Attribute expressions for struct field access."""
183-
if isinstance(expr.value, ast.Name):
184-
var_name = expr.value.id
185-
attr_name = expr.attr
186-
if var_name in local_sym_tab:
187-
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
188-
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
189-
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
190-
metadata = structs_sym_tab[var_metadata]
191-
if attr_name in metadata.fields:
192-
gep = metadata.gep(builder, var_ptr, attr_name)
193-
val = builder.load(gep)
194-
field_type = metadata.field_type(attr_name)
195-
return val, field_type
196-
return None
197-
198-
199-
def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder):
200-
"""Handle deref function calls."""
201-
logger.info(f"Handling deref {ast.dump(expr)}")
202-
if len(expr.args) != 1:
203-
logger.info("deref takes exactly one argument")
204-
return None
205-
206-
arg = expr.args[0]
207-
if (
208-
isinstance(arg, ast.Call)
209-
and isinstance(arg.func, ast.Name)
210-
and arg.func.id == "deref"
211-
):
212-
logger.info("Multiple deref not supported")
213-
return None
214-
215-
if isinstance(arg, ast.Name):
216-
if arg.id in local_sym_tab:
217-
arg_ptr = local_sym_tab[arg.id].var
218-
else:
219-
logger.info(f"Undefined variable {arg.id}")
220-
return None
221-
else:
222-
logger.info("Unsupported argument type for deref")
223-
return None
224-
225-
if arg_ptr is None:
226-
logger.info("Failed to evaluate deref argument")
227-
return None
228-
229-
# Load the value from pointer
230-
val = builder.load(arg_ptr)
231-
return val, local_sym_tab[arg.id].ir_type
222+
# ============================================================================
223+
# Comparison and Unary Operations
224+
# ============================================================================
232225

233226

234227
def _handle_ctypes_call(
@@ -341,6 +334,11 @@ def _handle_unary_op(
341334
return result, ir.IntType(64)
342335

343336

337+
# ============================================================================
338+
# Boolean Operations
339+
# ============================================================================
340+
341+
344342
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
345343
"""Handle `and` boolean operations."""
346344

@@ -471,6 +469,11 @@ def _handle_boolean_op(
471469
return None
472470

473471

472+
# ============================================================================
473+
# Expression Dispatcher
474+
# ============================================================================
475+
476+
474477
def eval_expr(
475478
func,
476479
module,

0 commit comments

Comments
 (0)