Skip to content

Commit 31645f0

Browse files
authored
Merge pull request #40 from pythonbpf/refactor_assign
Refactor assignment statement handling and the typing mechanism around it
2 parents 21ce041 + e0ad1bf commit 31645f0

File tree

15 files changed

+872
-368
lines changed

15 files changed

+872
-368
lines changed

pythonbpf/allocation_pass.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import ast
2+
import logging
3+
4+
from llvmlite import ir
5+
from dataclasses import dataclass
6+
from typing import Any
7+
from pythonbpf.helper import HelperHandlerRegistry
8+
from pythonbpf.type_deducer import ctypes_to_ir
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class LocalSymbol:
15+
var: ir.AllocaInstr
16+
ir_type: ir.Type
17+
metadata: Any = None
18+
19+
def __iter__(self):
20+
yield self.var
21+
yield self.ir_type
22+
yield self.metadata
23+
24+
25+
def _is_helper_call(call_node):
26+
"""Check if a call node is a BPF helper function call."""
27+
if isinstance(call_node.func, ast.Name):
28+
# Exclude print from requiring temps (handles f-strings differently)
29+
func_name = call_node.func.id
30+
return HelperHandlerRegistry.has_handler(func_name) and func_name != "print"
31+
32+
elif isinstance(call_node.func, ast.Attribute):
33+
return HelperHandlerRegistry.has_handler(call_node.func.attr)
34+
35+
return False
36+
37+
38+
def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
39+
"""Handle memory allocation for assignment statements."""
40+
41+
# Validate assignment
42+
if len(stmt.targets) != 1:
43+
logger.warning("Multi-target assignment not supported, skipping allocation")
44+
return
45+
46+
target = stmt.targets[0]
47+
48+
# Skip non-name targets (e.g., struct field assignments)
49+
if isinstance(target, ast.Attribute):
50+
logger.debug(f"Struct field assignment to {target.attr}, no allocation needed")
51+
return
52+
53+
if not isinstance(target, ast.Name):
54+
logger.warning(f"Unsupported assignment target type: {type(target).__name__}")
55+
return
56+
57+
var_name = target.id
58+
rval = stmt.value
59+
60+
# Skip if already allocated
61+
if var_name in local_sym_tab:
62+
logger.debug(f"Variable {var_name} already allocated, skipping")
63+
return
64+
65+
# Determine type and allocate based on rval
66+
if isinstance(rval, ast.Call):
67+
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)
68+
elif isinstance(rval, ast.Constant):
69+
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
70+
elif isinstance(rval, ast.BinOp):
71+
_allocate_for_binop(builder, var_name, local_sym_tab)
72+
else:
73+
logger.warning(
74+
f"Unsupported assignment value type for {var_name}: {type(rval).__name__}"
75+
)
76+
77+
78+
def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab):
79+
"""Allocate memory for variable assigned from a call."""
80+
81+
if isinstance(rval.func, ast.Name):
82+
call_type = rval.func.id
83+
84+
# C type constructors
85+
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
86+
ir_type = ctypes_to_ir(call_type)
87+
var = builder.alloca(ir_type, name=var_name)
88+
var.align = ir_type.width // 8
89+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
90+
logger.info(f"Pre-allocated {var_name} as {call_type}")
91+
92+
# Helper functions
93+
elif HelperHandlerRegistry.has_handler(call_type):
94+
ir_type = ir.IntType(64) # Assume i64 return type
95+
var = builder.alloca(ir_type, name=var_name)
96+
var.align = 8
97+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
98+
logger.info(f"Pre-allocated {var_name} for helper {call_type}")
99+
100+
# Deref function
101+
elif call_type == "deref":
102+
ir_type = ir.IntType(64) # Assume i64 return type
103+
var = builder.alloca(ir_type, name=var_name)
104+
var.align = 8
105+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
106+
logger.info(f"Pre-allocated {var_name} for deref")
107+
108+
# Struct constructors
109+
elif call_type in structs_sym_tab:
110+
struct_info = structs_sym_tab[call_type]
111+
var = builder.alloca(struct_info.ir_type, name=var_name)
112+
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
113+
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
114+
115+
else:
116+
logger.warning(f"Unknown call type for allocation: {call_type}")
117+
118+
elif isinstance(rval.func, ast.Attribute):
119+
# Map method calls - need double allocation for ptr handling
120+
_allocate_for_map_method(builder, var_name, local_sym_tab)
121+
122+
else:
123+
logger.warning(f"Unsupported call function type for {var_name}")
124+
125+
126+
def _allocate_for_map_method(builder, var_name, local_sym_tab):
127+
"""Allocate memory for variable assigned from map method (double alloc)."""
128+
129+
# Main variable (pointer to pointer)
130+
ir_type = ir.PointerType(ir.IntType(64))
131+
var = builder.alloca(ir_type, name=var_name)
132+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
133+
134+
# Temporary variable for computed values
135+
tmp_ir_type = ir.IntType(64)
136+
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
137+
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
138+
139+
logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method")
140+
141+
142+
def _allocate_for_constant(builder, var_name, rval, local_sym_tab):
143+
"""Allocate memory for variable assigned from a constant."""
144+
145+
if isinstance(rval.value, bool):
146+
ir_type = ir.IntType(1)
147+
var = builder.alloca(ir_type, name=var_name)
148+
var.align = 1
149+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
150+
logger.info(f"Pre-allocated {var_name} as bool")
151+
152+
elif isinstance(rval.value, int):
153+
ir_type = ir.IntType(64)
154+
var = builder.alloca(ir_type, name=var_name)
155+
var.align = 8
156+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
157+
logger.info(f"Pre-allocated {var_name} as i64")
158+
159+
elif isinstance(rval.value, str):
160+
ir_type = ir.PointerType(ir.IntType(8))
161+
var = builder.alloca(ir_type, name=var_name)
162+
var.align = 8
163+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
164+
logger.info(f"Pre-allocated {var_name} as string")
165+
166+
else:
167+
logger.warning(
168+
f"Unsupported constant type for {var_name}: {type(rval.value).__name__}"
169+
)
170+
171+
172+
def _allocate_for_binop(builder, var_name, local_sym_tab):
173+
"""Allocate memory for variable assigned from a binary operation."""
174+
ir_type = ir.IntType(64) # Assume i64 result
175+
var = builder.alloca(ir_type, name=var_name)
176+
var.align = 8
177+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
178+
logger.info(f"Pre-allocated {var_name} for binop result")
179+
180+
181+
def allocate_temp_pool(builder, max_temps, local_sym_tab):
182+
"""Allocate the temporary scratch space pool for helper arguments."""
183+
if max_temps == 0:
184+
return
185+
186+
logger.info(f"Allocating temp pool of {max_temps} variables")
187+
for i in range(max_temps):
188+
temp_name = f"__helper_temp_{i}"
189+
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
190+
temp_var.align = 8
191+
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))

pythonbpf/assign_pass.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import ast
2+
import logging
3+
from llvmlite import ir
4+
from pythonbpf.expr import eval_expr
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def handle_struct_field_assignment(
10+
func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab
11+
):
12+
"""Handle struct field assignment (obj.field = value)."""
13+
14+
var_name = target.value.id
15+
field_name = target.attr
16+
17+
if var_name not in local_sym_tab:
18+
logger.error(f"Variable '{var_name}' not found in symbol table")
19+
return
20+
21+
struct_type = local_sym_tab[var_name].metadata
22+
struct_info = structs_sym_tab[struct_type]
23+
24+
if field_name not in struct_info.fields:
25+
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
26+
return
27+
28+
# Get field pointer and evaluate value
29+
field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
30+
val = eval_expr(
31+
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
32+
)
33+
34+
if val is None:
35+
logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
36+
return
37+
38+
# TODO: Handle string assignment to char array (not a priority)
39+
field_type = struct_info.field_type(field_name)
40+
if isinstance(field_type, ir.ArrayType) and val[1] == ir.PointerType(ir.IntType(8)):
41+
logger.warning(
42+
f"String to char array assignment not implemented for {var_name}.{field_name}"
43+
)
44+
return
45+
46+
# Store the value
47+
builder.store(val[0], field_ptr)
48+
logger.info(f"Assigned to struct field {var_name}.{field_name}")
49+
50+
51+
def handle_variable_assignment(
52+
func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
53+
):
54+
"""Handle single named variable assignment."""
55+
56+
if var_name not in local_sym_tab:
57+
logger.error(f"Variable {var_name} not declared.")
58+
return False
59+
60+
var_ptr = local_sym_tab[var_name].var
61+
var_type = local_sym_tab[var_name].ir_type
62+
63+
# NOTE: Special case for struct initialization
64+
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
65+
struct_name = rval.func.id
66+
if struct_name in structs_sym_tab and len(rval.args) == 0:
67+
struct_info = structs_sym_tab[struct_name]
68+
ir_struct = struct_info.ir_type
69+
70+
builder.store(ir.Constant(ir_struct, None), var_ptr)
71+
logger.info(f"Initialized struct {struct_name} for variable {var_name}")
72+
return True
73+
74+
val_result = eval_expr(
75+
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
76+
)
77+
if val_result is None:
78+
logger.error(f"Failed to evaluate value for {var_name}")
79+
return False
80+
81+
val, val_type = val_result
82+
logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}")
83+
if val_type != var_type:
84+
if isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType):
85+
# Allow implicit int widening
86+
if val_type.width < var_type.width:
87+
val = builder.sext(val, var_type)
88+
logger.info(f"Implicitly widened int for variable {var_name}")
89+
elif val_type.width > var_type.width:
90+
val = builder.trunc(val, var_type)
91+
logger.info(f"Implicitly truncated int for variable {var_name}")
92+
elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.PointerType):
93+
# NOTE: This is assignment to a PTR_TO_MAP_VALUE_OR_NULL
94+
logger.info(
95+
f"Creating temporary variable for pointer assignment to {var_name}"
96+
)
97+
var_ptr_tmp = local_sym_tab[f"{var_name}_tmp"].var
98+
builder.store(val, var_ptr_tmp)
99+
val = var_ptr_tmp
100+
else:
101+
logger.error(
102+
f"Type mismatch for variable {var_name}: {val_type} vs {var_type}"
103+
)
104+
return False
105+
106+
builder.store(val, var_ptr)
107+
logger.info(f"Assigned value to variable {var_name}")
108+
return True

0 commit comments

Comments
 (0)