Skip to content

Commit 8485460

Browse files
authored
Merge pull request #26 from pythonbpf/refactor_conds
Refactor conds
2 parents 0c97751 + 9fdc6fa commit 8485460

File tree

22 files changed

+776
-76
lines changed

22 files changed

+776
-76
lines changed

pythonbpf/expr/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .expr_pass import eval_expr, handle_expr
2+
from .type_normalization import convert_to_bool
3+
4+
__all__ = ["eval_expr", "handle_expr", "convert_to_bool"]

pythonbpf/expr_pass.py renamed to pythonbpf/expr/expr_pass.py

Lines changed: 210 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import logging
55
from typing import Dict
66

7-
from .type_deducer import ctypes_to_ir, is_ctypes
7+
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
8+
from .type_normalization import convert_to_bool, handle_comparator
89

910
logger: Logger = logging.getLogger(__name__)
1011

@@ -22,12 +23,10 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
2223

2324
def _handle_constant_expr(expr: ast.Constant):
2425
"""Handle ast.Constant expressions."""
25-
if isinstance(expr.value, int):
26-
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
27-
elif isinstance(expr.value, bool):
28-
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
26+
if isinstance(expr.value, int) or isinstance(expr.value, bool):
27+
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
2928
else:
30-
logger.info("Unsupported constant type")
29+
logger.error("Unsupported constant type")
3130
return None
3231

3332

@@ -45,7 +44,6 @@ def _handle_attribute_expr(
4544
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
4645
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
4746
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
48-
4947
metadata = structs_sym_tab[var_metadata]
5048
if attr_name in metadata.fields:
5149
gep = metadata.gep(builder, var_ptr, attr_name)
@@ -132,6 +130,199 @@ def _handle_ctypes_call(
132130
return val
133131

134132

133+
def _handle_compare(
134+
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
135+
):
136+
"""Handle ast.Compare expressions."""
137+
138+
if len(cond.ops) != 1 or len(cond.comparators) != 1:
139+
logger.error("Only single comparisons are supported")
140+
return None
141+
lhs = eval_expr(
142+
func,
143+
module,
144+
builder,
145+
cond.left,
146+
local_sym_tab,
147+
map_sym_tab,
148+
structs_sym_tab,
149+
)
150+
rhs = eval_expr(
151+
func,
152+
module,
153+
builder,
154+
cond.comparators[0],
155+
local_sym_tab,
156+
map_sym_tab,
157+
structs_sym_tab,
158+
)
159+
160+
if lhs is None or rhs is None:
161+
logger.error("Failed to evaluate comparison operands")
162+
return None
163+
164+
lhs, _ = lhs
165+
rhs, _ = rhs
166+
return handle_comparator(func, builder, cond.ops[0], lhs, rhs)
167+
168+
169+
def _handle_unary_op(
170+
func,
171+
module,
172+
builder,
173+
expr: ast.UnaryOp,
174+
local_sym_tab,
175+
map_sym_tab,
176+
structs_sym_tab=None,
177+
):
178+
"""Handle ast.UnaryOp expressions."""
179+
if not isinstance(expr.op, ast.Not):
180+
logger.error("Only 'not' unary operator is supported")
181+
return None
182+
183+
operand = eval_expr(
184+
func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab
185+
)
186+
if operand is None:
187+
logger.error("Failed to evaluate operand for unary operation")
188+
return None
189+
190+
operand_val, operand_type = operand
191+
true_const = ir.Constant(ir.IntType(1), 1)
192+
result = builder.xor(convert_to_bool(builder, operand_val), true_const)
193+
return result, ir.IntType(1)
194+
195+
196+
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
197+
"""Handle `and` boolean operations."""
198+
199+
logger.debug(f"Handling 'and' operator with {len(expr.values)} operands")
200+
201+
merge_block = func.append_basic_block(name="and.merge")
202+
false_block = func.append_basic_block(name="and.false")
203+
204+
incoming_values = []
205+
206+
for i, value in enumerate(expr.values):
207+
is_last = i == len(expr.values) - 1
208+
209+
# Evaluate current operand
210+
operand_result = eval_expr(
211+
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
212+
)
213+
if operand_result is None:
214+
logger.error(f"Failed to evaluate operand {i} in 'and' expression")
215+
return None
216+
217+
operand_val, operand_type = operand_result
218+
219+
# Convert to boolean if needed
220+
operand_bool = convert_to_bool(builder, operand_val)
221+
current_block = builder.block
222+
223+
if is_last:
224+
# Last operand: result is this value
225+
builder.branch(merge_block)
226+
incoming_values.append((operand_bool, current_block))
227+
else:
228+
# Not last: check if true, continue or short-circuit
229+
next_check = func.append_basic_block(name=f"and.check_{i + 1}")
230+
builder.cbranch(operand_bool, next_check, false_block)
231+
builder.position_at_end(next_check)
232+
233+
# False block: short-circuit with false
234+
builder.position_at_end(false_block)
235+
builder.branch(merge_block)
236+
false_value = ir.Constant(ir.IntType(1), 0)
237+
incoming_values.append((false_value, false_block))
238+
239+
# Merge block: phi node
240+
builder.position_at_end(merge_block)
241+
phi = builder.phi(ir.IntType(1), name="and.result")
242+
for val, block in incoming_values:
243+
phi.add_incoming(val, block)
244+
245+
logger.debug(f"Generated 'and' with {len(incoming_values)} incoming values")
246+
return phi, ir.IntType(1)
247+
248+
249+
def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
250+
"""Handle `or` boolean operations."""
251+
252+
logger.debug(f"Handling 'or' operator with {len(expr.values)} operands")
253+
254+
merge_block = func.append_basic_block(name="or.merge")
255+
true_block = func.append_basic_block(name="or.true")
256+
257+
incoming_values = []
258+
259+
for i, value in enumerate(expr.values):
260+
is_last = i == len(expr.values) - 1
261+
262+
# Evaluate current operand
263+
operand_result = eval_expr(
264+
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
265+
)
266+
if operand_result is None:
267+
logger.error(f"Failed to evaluate operand {i} in 'or' expression")
268+
return None
269+
270+
operand_val, operand_type = operand_result
271+
272+
# Convert to boolean if needed
273+
operand_bool = convert_to_bool(builder, operand_val)
274+
current_block = builder.block
275+
276+
if is_last:
277+
# Last operand: result is this value
278+
builder.branch(merge_block)
279+
incoming_values.append((operand_bool, current_block))
280+
else:
281+
# Not last: check if false, continue or short-circuit
282+
next_check = func.append_basic_block(name=f"or.check_{i + 1}")
283+
builder.cbranch(operand_bool, true_block, next_check)
284+
builder.position_at_end(next_check)
285+
286+
# True block: short-circuit with true
287+
builder.position_at_end(true_block)
288+
builder.branch(merge_block)
289+
true_value = ir.Constant(ir.IntType(1), 1)
290+
incoming_values.append((true_value, true_block))
291+
292+
# Merge block: phi node
293+
builder.position_at_end(merge_block)
294+
phi = builder.phi(ir.IntType(1), name="or.result")
295+
for val, block in incoming_values:
296+
phi.add_incoming(val, block)
297+
298+
logger.debug(f"Generated 'or' with {len(incoming_values)} incoming values")
299+
return phi, ir.IntType(1)
300+
301+
302+
def _handle_boolean_op(
303+
func,
304+
module,
305+
builder,
306+
expr: ast.BoolOp,
307+
local_sym_tab,
308+
map_sym_tab,
309+
structs_sym_tab=None,
310+
):
311+
"""Handle `and` and `or` boolean operations."""
312+
313+
if isinstance(expr.op, ast.And):
314+
return _handle_and_op(
315+
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
316+
)
317+
elif isinstance(expr.op, ast.Or):
318+
return _handle_or_op(
319+
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
320+
)
321+
else:
322+
logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}")
323+
return None
324+
325+
135326
def eval_expr(
136327
func,
137328
module,
@@ -212,6 +403,18 @@ def eval_expr(
212403
from pythonbpf.binary_ops import handle_binary_op
213404

214405
return handle_binary_op(expr, builder, None, local_sym_tab)
406+
elif isinstance(expr, ast.Compare):
407+
return _handle_compare(
408+
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
409+
)
410+
elif isinstance(expr, ast.UnaryOp):
411+
return _handle_unary_op(
412+
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
413+
)
414+
elif isinstance(expr, ast.BoolOp):
415+
return _handle_boolean_op(
416+
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
417+
)
215418
logger.info("Unsupported expression evaluation")
216419
return None
217420

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from llvmlite import ir
2+
import logging
3+
import ast
4+
5+
logger = logging.getLogger(__name__)
6+
7+
COMPARISON_OPS = {
8+
ast.Eq: "==",
9+
ast.NotEq: "!=",
10+
ast.Lt: "<",
11+
ast.LtE: "<=",
12+
ast.Gt: ">",
13+
ast.GtE: ">=",
14+
ast.Is: "==",
15+
ast.IsNot: "!=",
16+
}
17+
18+
19+
def _get_base_type_and_depth(ir_type):
20+
"""Get the base type for pointer types."""
21+
cur_type = ir_type
22+
depth = 0
23+
while isinstance(cur_type, ir.PointerType):
24+
depth += 1
25+
cur_type = cur_type.pointee
26+
return cur_type, depth
27+
28+
29+
def _deref_to_depth(func, builder, val, target_depth):
30+
"""Dereference a pointer to a certain depth."""
31+
32+
cur_val = val
33+
cur_type = val.type
34+
35+
for depth in range(target_depth):
36+
if not isinstance(val.type, ir.PointerType):
37+
logger.error("Cannot dereference further, non-pointer type")
38+
return None
39+
40+
# dereference with null check
41+
pointee_type = cur_type.pointee
42+
null_check_block = builder.block
43+
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
44+
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
45+
46+
null_ptr = ir.Constant(cur_type, None)
47+
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
48+
logger.debug(f"Inserted null check for pointer at depth {depth}")
49+
50+
builder.cbranch(is_not_null, not_null_block, merge_block)
51+
52+
builder.position_at_end(not_null_block)
53+
dereferenced_val = builder.load(cur_val)
54+
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
55+
builder.branch(merge_block)
56+
57+
builder.position_at_end(merge_block)
58+
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
59+
60+
zero_value = (
61+
ir.Constant(pointee_type, 0)
62+
if isinstance(pointee_type, ir.IntType)
63+
else ir.Constant(pointee_type, None)
64+
)
65+
phi.add_incoming(zero_value, null_check_block)
66+
67+
phi.add_incoming(dereferenced_val, not_null_block)
68+
69+
# Continue with phi result
70+
cur_val = phi
71+
cur_type = pointee_type
72+
return cur_val
73+
74+
75+
def _normalize_types(func, builder, lhs, rhs):
76+
"""Normalize types for comparison."""
77+
78+
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
79+
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
80+
if lhs.type.width < rhs.type.width:
81+
lhs = builder.sext(lhs, rhs.type)
82+
else:
83+
rhs = builder.sext(rhs, lhs.type)
84+
return lhs, rhs
85+
elif not isinstance(lhs.type, ir.PointerType) and not isinstance(
86+
rhs.type, ir.PointerType
87+
):
88+
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
89+
return None, None
90+
else:
91+
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
92+
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
93+
if lhs_base == rhs_base:
94+
if lhs_depth < rhs_depth:
95+
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
96+
elif rhs_depth < lhs_depth:
97+
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
98+
return _normalize_types(func, builder, lhs, rhs)
99+
100+
101+
def convert_to_bool(builder, val):
102+
"""Convert a value to boolean."""
103+
if val.type == ir.IntType(1):
104+
return val
105+
if isinstance(val.type, ir.PointerType):
106+
zero = ir.Constant(val.type, None)
107+
else:
108+
zero = ir.Constant(val.type, 0)
109+
return builder.icmp_signed("!=", val, zero)
110+
111+
112+
def handle_comparator(func, builder, op, lhs, rhs):
113+
"""Handle comparison operations."""
114+
115+
if lhs.type != rhs.type:
116+
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
117+
118+
if lhs is None or rhs is None:
119+
return None
120+
121+
if type(op) not in COMPARISON_OPS:
122+
logger.error(f"Unsupported comparison operator: {type(op)}")
123+
return None
124+
125+
predicate = COMPARISON_OPS[type(op)]
126+
result = builder.icmp_signed(predicate, lhs, rhs)
127+
logger.debug(f"Comparison result: {result}")
128+
return result, ir.IntType(1)

0 commit comments

Comments
 (0)