Skip to content
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
6008d98
Change loglevel of multi-assignment warning in handle_assign
r41k0u Oct 8, 2025
84ed27f
Add handle_variable_assignment stub and boilerplate in handle_assign
r41k0u Oct 8, 2025
d7bfe86
Add handle_variable_assignment to assign_pass
r41k0u Oct 8, 2025
054a834
Add failing assign test retype.py, with explanation
r41k0u Oct 8, 2025
c596213
Add cst_var_binop.py as passing assign test
r41k0u Oct 8, 2025
23afb0b
Add deref_to_val to deref into final value and return the chain as we…
r41k0u Oct 9, 2025
1253f51
Use deref_to_val instead of recursive_dereferencer in get_operand value
r41k0u Oct 9, 2025
8bab07e
Remove recursive_dereferencer
r41k0u Oct 9, 2025
489244a
Add store_through_chain
r41k0u Oct 9, 2025
047f361
Allocate twice for map lookups
r41k0u Oct 10, 2025
1d517d4
Add double_alloc in alloc_mem
r41k0u Oct 10, 2025
99aacca
WIP: allow pointer assignments to var
r41k0u Oct 10, 2025
9febadf
Add pointer handling to helper_utils, finish pointer assignment
r41k0u Oct 10, 2025
7529820
Allow int** pointers to store binops of type int** op int
r41k0u Oct 10, 2025
a756f5e
Add passing helper test for assignment
r41k0u Oct 10, 2025
3175756
Interpret bools as ints in binops
r41k0u Oct 10, 2025
cac88d1
Allow different int widths in binops
r41k0u Oct 10, 2025
c2c1774
Remove store_through_chain
r41k0u Oct 10, 2025
91a3fe1
Remove unnecessary return artifacts from get_operand_value
r41k0u Oct 10, 2025
c9bbe1f
Call eval_expr properly within get_operand_value
r41k0u Oct 10, 2025
8b7b1c0
Add struct_and_helper_binops passing test for assignments
r41k0u Oct 11, 2025
8776d76
Add count_temps_in_call to call scratch space needed in a helper call
r41k0u Oct 11, 2025
321415f
Add update_max_temps_for_stmt in allocate_mem
r41k0u Oct 11, 2025
6bce29b
Allocate scratch space for temp vars at the end of allocate_mem
r41k0u Oct 11, 2025
5dcf670
Add ScratchPoolManager and its singleton
r41k0u Oct 11, 2025
207f714
Use scratch space to store consts passed to helpers
r41k0u Oct 11, 2025
cd74e89
Allow binops as args to helpers accepting int*
r41k0u Oct 11, 2025
d66e6a6
Allow struct members as helper args
r41k0u Oct 12, 2025
2cf68f6
Allow map-based helpers to be used as helper args / within binops whi…
r41k0u Oct 12, 2025
4e33fd4
Add negation UnaryOp
r41k0u Oct 12, 2025
a3b4d09
Fix errorstring in _handle_unary_op
r41k0u Oct 12, 2025
e8026a1
Allow helpers to be called within themselves
r41k0u Oct 12, 2025
fa82dc7
Add comprehensive passing test for assignment
r41k0u Oct 12, 2025
b93f704
Tweak the comprehensive assignment test
r41k0u Oct 12, 2025
933d2a5
Fix comprehensive assignment test
r41k0u Oct 12, 2025
105c5a7
Cleanup handle_assign
r41k0u Oct 12, 2025
3ad1b73
Add handle_struct_field_assignment to assign_pass
r41k0u Oct 12, 2025
64e44d0
Use handle_struct_field_assignment in handle_assign
r41k0u Oct 12, 2025
08c0ccf
Pass map_sym_tab to handle_struct_field_assign
r41k0u Oct 12, 2025
0f6971b
Refactor allocate_mem
r41k0u Oct 12, 2025
2f1aaa4
Fix typos
r41k0u Oct 12, 2025
69bee5f
Separate LocalSymbol from functions
r41k0u Oct 12, 2025
e0ad1bf
Move bulk of allocation logic to allocation_pass
r41k0u Oct 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions pythonbpf/assign_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import ast
import logging
from llvmlite import ir
from pythonbpf.expr import eval_expr

logger = logging.getLogger(__name__)


def handle_struct_field_assignment(
    func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
    """Handle struct field assignment (obj.field = value).

    Looks up ``obj`` in ``local_sym_tab``, resolves its struct description
    through ``structs_sym_tab``, evaluates ``rval`` with ``eval_expr`` and
    stores the result through the field pointer obtained from the struct's
    ``gep`` helper.

    Args:
        func, module, builder: Forwarded unchanged to ``eval_expr``;
            ``builder`` is also used to emit the conversion and store.
        target: The ``ast.Attribute`` node on the left-hand side.
        rval: The AST expression on the right-hand side.
        local_sym_tab: Mapping of variable name to a symbol exposing
            ``.var`` (the variable's pointer) and ``.metadata`` (used here
            as the key into ``structs_sym_tab``).
        map_sym_tab: Forwarded unchanged to ``eval_expr``.
        structs_sym_tab: Mapping of struct name to a struct description
            exposing ``fields``, ``gep`` and ``field_type``.

    Returns:
        None. Every failure is reported through the module logger and the
        function returns without emitting a store.
    """
    # Only `name.field = ...` is supported. A nested target such as
    # `a.b.c = ...` has an ast.Attribute as its base, which has no `.id`.
    if not isinstance(target.value, ast.Name):
        logger.error(
            f"Unsupported struct field assignment target: {ast.dump(target)}"
        )
        return

    var_name = target.value.id
    field_name = target.attr

    if var_name not in local_sym_tab:
        logger.error(f"Variable '{var_name}' not found in symbol table")
        return

    struct_type = local_sym_tab[var_name].metadata
    # A non-struct local has no entry here; report it instead of raising
    # KeyError out of the pass.
    if struct_type not in structs_sym_tab:
        logger.error(
            f"Variable '{var_name}' is not a known struct (got '{struct_type}')"
        )
        return
    struct_info = structs_sym_tab[struct_type]

    if field_name not in struct_info.fields:
        logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
        return

    # Get field pointer and evaluate value
    field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
    val = eval_expr(
        func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
    )

    if val is None:
        logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
        return

    value, value_type = val[0], val[1]

    # TODO: Handle string assignment to char array (not a priority)
    field_type = struct_info.field_type(field_name)
    if isinstance(field_type, ir.ArrayType) and value_type == ir.PointerType(
        ir.IntType(8)
    ):
        logger.warning(
            f"String to char array assignment not implemented for {var_name}.{field_name}"
        )
        return

    # Integer values may be wider or narrower than the field. builder.store
    # rejects mismatched types, so convert first (same policy as plain
    # variable assignment).
    if (
        isinstance(field_type, ir.IntType)
        and isinstance(value_type, ir.IntType)
        and value_type.width != field_type.width
    ):
        if value_type.width > field_type.width:
            value = builder.trunc(value, field_type)
            logger.info(f"Implicitly truncated int for {var_name}.{field_name}")
        elif value_type.width == 1:
            # An i1 is a boolean: zero-extend so True becomes 1, not -1.
            value = builder.zext(value, field_type)
            logger.info(f"Implicitly widened bool for {var_name}.{field_name}")
        else:
            value = builder.sext(value, field_type)
            logger.info(f"Implicitly widened int for {var_name}.{field_name}")

    # Store the value
    builder.store(value, field_ptr)
    logger.info(f"Assigned to struct field {var_name}.{field_name}")


def handle_variable_assignment(
    func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
    """Handle single named variable assignment.

    Evaluates ``rval`` with ``eval_expr`` and stores the result into the
    pointer recorded for ``var_name`` in ``local_sym_tab``, converting
    between integer widths when the value and variable types differ.

    Args:
        func, module, builder: Forwarded unchanged to ``eval_expr``;
            ``builder`` is also used to emit conversions and stores.
        var_name: Name of the (already declared) target variable.
        rval: The AST expression on the right-hand side.
        local_sym_tab: Mapping of variable name to a symbol exposing
            ``.var`` (the variable's pointer) and ``.ir_type``.
        map_sym_tab: Forwarded unchanged to ``eval_expr``.
        structs_sym_tab: Mapping of struct name to a struct description
            exposing ``ir_type``.

    Returns:
        True when a store was emitted, False when the assignment could not
        be handled (the reason is logged).
    """

    if var_name not in local_sym_tab:
        logger.error(f"Variable {var_name} not declared.")
        return False

    var_ptr = local_sym_tab[var_name].var
    var_type = local_sym_tab[var_name].ir_type

    # NOTE: Special case for struct initialization
    if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
        struct_name = rval.func.id
        if struct_name in structs_sym_tab and len(rval.args) == 0:
            struct_info = structs_sym_tab[struct_name]
            ir_struct = struct_info.ir_type

            # `StructName()` with no arguments: store a None-valued constant
            # of the struct type (presumably llvmlite's zero initializer;
            # confirm against the llvmlite Constant docs).
            builder.store(ir.Constant(ir_struct, None), var_ptr)
            logger.info(f"Initialized struct {struct_name} for variable {var_name}")
            return True

    val_result = eval_expr(
        func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
    )
    if val_result is None:
        logger.error(f"Failed to evaluate value for {var_name}")
        return False

    val, val_type = val_result
    logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}")
    if val_type != var_type:
        if isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType):
            # Allow implicit int widening
            if val_type.width < var_type.width:
                if val_type.width == 1:
                    # An i1 is a boolean (result of `not` / comparisons).
                    # Sign-extending it would turn True into -1, so
                    # zero-extend to keep True == 1.
                    val = builder.zext(val, var_type)
                else:
                    val = builder.sext(val, var_type)
                logger.info(f"Implicitly widened int for variable {var_name}")
            elif val_type.width > var_type.width:
                val = builder.trunc(val, var_type)
                logger.info(f"Implicitly truncated int for variable {var_name}")
        elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.PointerType):
            # NOTE: This is assignment to a PTR_TO_MAP_VALUE_OR_NULL
            logger.info(
                f"Creating temporary variable for pointer assignment to {var_name}"
            )
            # The integer is parked in a companion "<name>_tmp" slot and the
            # variable is pointed at that slot. The slot is looked up, not
            # created, here; fail cleanly if it is missing.
            tmp_name = f"{var_name}_tmp"
            if tmp_name not in local_sym_tab:
                logger.error(
                    f"Temporary variable {tmp_name} not allocated for "
                    f"pointer assignment to {var_name}"
                )
                return False
            var_ptr_tmp = local_sym_tab[tmp_name].var
            builder.store(val, var_ptr_tmp)
            val = var_ptr_tmp
        else:
            logger.error(
                f"Type mismatch for variable {var_name}: {val_type} vs {var_type}"
            )
            return False

    builder.store(val, var_ptr)
    logger.info(f"Assigned value to variable {var_name}")
    return True
82 changes: 60 additions & 22 deletions pythonbpf/binary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,70 @@
from logging import Logger
import logging

logger: Logger = logging.getLogger(__name__)

from pythonbpf.expr import get_base_type_and_depth, deref_to_depth, eval_expr

def recursive_dereferencer(var, builder):
"""dereference until primitive type comes out"""
# TODO: Not worrying about stack overflow for now
logger.info(f"Dereferencing {var}, type is {var.type}")
if isinstance(var.type, ir.PointerType):
a = builder.load(var)
return recursive_dereferencer(a, builder)
elif isinstance(var.type, ir.IntType):
return var
else:
raise TypeError(f"Unsupported type for dereferencing: {var.type}")
logger: Logger = logging.getLogger(__name__)


def get_operand_value(operand, builder, local_sym_tab):
def get_operand_value(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name):
if operand.id in local_sym_tab:
return recursive_dereferencer(local_sym_tab[operand.id].var, builder)
var = local_sym_tab[operand.id].var
var_type = var.type
base_type, depth = get_base_type_and_depth(var_type)
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
return ir.Constant(ir.IntType(64), operand.value)
cst = ir.Constant(ir.IntType(64), int(operand.value))
return cst
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, builder, local_sym_tab)
res = handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
return res
else:
res = eval_expr(
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res
logger.info(f"Evaluated expr to {val} of type {val.type}")
base_type, depth = get_base_type_and_depth(val.type)
if depth > 0:
val = deref_to_depth(func, builder, val, depth)
return val
raise TypeError(f"Unsupported operand type: {type(operand)}")


def handle_binary_op_impl(rval, builder, local_sym_tab):
def handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
op = rval.op
left = get_operand_value(rval.left, builder, local_sym_tab)
right = get_operand_value(rval.right, builder, local_sym_tab)
left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
logger.info(f"left is {left}, right is {right}, op is {op}")

# NOTE: Before doing the operation, if the operands are integers
# we always extend them to i64. The assignment to LHS will take
# care of truncation if needed.
if isinstance(left.type, ir.IntType) and left.type.width < 64:
left = builder.sext(left, ir.IntType(64))
if isinstance(right.type, ir.IntType) and right.type.width < 64:
right = builder.sext(right, ir.IntType(64))

# Map AST operation nodes to LLVM IR builder methods
op_map = {
ast.Add: builder.add,
Expand All @@ -62,8 +89,19 @@ def handle_binary_op_impl(rval, builder, local_sym_tab):
raise SyntaxError("Unsupported binary operation")


def handle_binary_op(rval, builder, var_name, local_sym_tab):
result = handle_binary_op_impl(rval, builder, local_sym_tab)
def handle_binary_op(
func,
module,
rval,
builder,
var_name,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
result = handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
Expand Down
10 changes: 8 additions & 2 deletions pythonbpf/expr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .expr_pass import eval_expr, handle_expr
from .type_normalization import convert_to_bool
from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth

__all__ = ["eval_expr", "handle_expr", "convert_to_bool"]
__all__ = [
"eval_expr",
"handle_expr",
"convert_to_bool",
"get_base_type_and_depth",
"deref_to_depth",
]
36 changes: 26 additions & 10 deletions pythonbpf/expr/expr_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _handle_constant_expr(expr: ast.Constant):
if isinstance(expr.value, int) or isinstance(expr.value, bool):
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
else:
logger.error("Unsupported constant type")
logger.error(f"Unsupported constant type {ast.dump(expr)}")
return None


Expand Down Expand Up @@ -176,21 +176,28 @@ def _handle_unary_op(
structs_sym_tab=None,
):
"""Handle ast.UnaryOp expressions."""
if not isinstance(expr.op, ast.Not):
logger.error("Only 'not' unary operator is supported")
if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub):
logger.error("Only 'not' and '-' unary operators are supported")
return None

operand = eval_expr(
func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab
from pythonbpf.binary_ops import get_operand_value

operand = get_operand_value(
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if operand is None:
logger.error("Failed to evaluate operand for unary operation")
return None

operand_val, operand_type = operand
true_const = ir.Constant(ir.IntType(1), 1)
result = builder.xor(convert_to_bool(builder, operand_val), true_const)
return result, ir.IntType(1)
if isinstance(expr.op, ast.Not):
true_const = ir.Constant(ir.IntType(1), 1)
result = builder.xor(convert_to_bool(builder, operand), true_const)
return result, ir.IntType(1)
elif isinstance(expr.op, ast.USub):
# Multiply by -1
neg_one = ir.Constant(ir.IntType(64), -1)
result = builder.mul(operand, neg_one)
return result, ir.IntType(64)


def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
Expand Down Expand Up @@ -402,7 +409,16 @@ def eval_expr(
elif isinstance(expr, ast.BinOp):
from pythonbpf.binary_ops import handle_binary_op

return handle_binary_op(expr, builder, None, local_sym_tab)
return handle_binary_op(
func,
module,
expr,
builder,
None,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr, ast.Compare):
return _handle_compare(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
Expand Down
12 changes: 6 additions & 6 deletions pythonbpf/expr/type_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
}


def _get_base_type_and_depth(ir_type):
def get_base_type_and_depth(ir_type):
"""Get the base type for pointer types."""
cur_type = ir_type
depth = 0
Expand All @@ -26,7 +26,7 @@ def _get_base_type_and_depth(ir_type):
return cur_type, depth


def _deref_to_depth(func, builder, val, target_depth):
def deref_to_depth(func, builder, val, target_depth):
"""Dereference a pointer to a certain depth."""

cur_val = val
Expand Down Expand Up @@ -88,13 +88,13 @@ def _normalize_types(func, builder, lhs, rhs):
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
return None, None
else:
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
lhs_base, lhs_depth = get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = get_base_type_and_depth(rhs.type)
if lhs_base == rhs_base:
if lhs_depth < rhs_depth:
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
rhs = deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
elif rhs_depth < lhs_depth:
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
lhs = deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
return _normalize_types(func, builder, lhs, rhs)


Expand Down
Loading