Skip to content

Commit 0f6971b

Browse files
committed
Refactor allocate_mem
1 parent 08c0ccf commit 0f6971b

File tree

1 file changed

+212
-132
lines changed

1 file changed

+212
-132
lines changed

pythonbpf/functions/functions_pass.py

Lines changed: 212 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,203 @@ def process_stmt(
220220
return did_return
221221

222222

223+
def _is_helper_call(call_node):
224+
"""Check if a call node is a BPF helper function call."""
225+
if isinstance(call_node.func, ast.Name):
226+
# Exclude print from requiring temps (handles f-strings differently)
227+
func_name = call_node.func.id
228+
return HelperHandlerRegistry.has_handler(func_name) and func_name != "print"
229+
230+
elif isinstance(call_node.func, ast.Attribute):
231+
return HelperHandlerRegistry.has_handler(call_node.func.attr)
232+
233+
return False
234+
235+
236+
def _handle_if_allocation(
237+
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
238+
):
239+
"""Recursively handle allocations in if/else branches."""
240+
if stmt.body:
241+
allocate_mem(
242+
module,
243+
builder,
244+
stmt.body,
245+
func,
246+
ret_type,
247+
map_sym_tab,
248+
local_sym_tab,
249+
structs_sym_tab,
250+
)
251+
if stmt.orelse:
252+
allocate_mem(
253+
module,
254+
builder,
255+
stmt.orelse,
256+
func,
257+
ret_type,
258+
map_sym_tab,
259+
local_sym_tab,
260+
structs_sym_tab,
261+
)
262+
263+
264+
def _handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
265+
"""Handle memory allocation for assignment statements."""
266+
267+
# Validate assignment
268+
if len(stmt.targets) != 1:
269+
logger.warning("Multi-target assignment not supported, skipping allocation")
270+
return
271+
272+
target = stmt.targets[0]
273+
274+
# Skip non-name targets (e.g., struct field assignments)
275+
if isinstance(target, ast.Attribute):
276+
logger.debug(f"Struct field assignment to {target.attr}, no allocation needed")
277+
return
278+
279+
if not isinstance(target, ast.Name):
280+
logger.warning(f"Unsupported assignment target type: {type(target).__name__}")
281+
return
282+
283+
var_name = target.id
284+
rval = stmt.value
285+
286+
# Skip if already allocated
287+
if var_name in local_sym_tab:
288+
logger.debug(f"Variable {var_name} already allocated, skipping")
289+
return
290+
291+
# Determine type and allocate based on rval
292+
if isinstance(rval, ast.Call):
293+
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)
294+
elif isinstance(rval, ast.Constant):
295+
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
296+
elif isinstance(rval, ast.BinOp):
297+
_allocate_for_binop(builder, var_name, local_sym_tab)
298+
else:
299+
logger.warning(
300+
f"Unsupported assignment value type for {var_name}: {type(rval).__name__}"
301+
)
302+
303+
304+
def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab):
305+
"""Allocate memory for variable assigned from a call."""
306+
307+
if isinstance(rval.func, ast.Name):
308+
call_type = rval.func.id
309+
310+
# C type constructors
311+
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
312+
ir_type = ctypes_to_ir(call_type)
313+
var = builder.alloca(ir_type, name=var_name)
314+
var.align = ir_type.width // 8
315+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
316+
logger.info(f"Pre-allocated {var_name} as {call_type}")
317+
318+
# Helper functions
319+
elif HelperHandlerRegistry.has_handler(call_type):
320+
ir_type = ir.IntType(64) # Assume i64 return type
321+
var = builder.alloca(ir_type, name=var_name)
322+
var.align = 8
323+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
324+
logger.info(f"Pre-allocated {var_name} for helper {call_type}")
325+
326+
# Deref function
327+
elif call_type == "deref":
328+
ir_type = ir.IntType(64) # Assume i64 return type
329+
var = builder.alloca(ir_type, name=var_name)
330+
var.align = 8
331+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
332+
logger.info(f"Pre-allocated {var_name} for deref")
333+
334+
# Struct constructors
335+
elif call_type in structs_sym_tab:
336+
struct_info = structs_sym_tab[call_type]
337+
var = builder.alloca(struct_info.ir_type, name=var_name)
338+
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
339+
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
340+
341+
else:
342+
logger.warning(f"Unknown call type for allocation: {call_type}")
343+
344+
elif isinstance(rval.func, ast.Attribute):
345+
# Map method calls - need double allocation for ptr handling
346+
_allocate_for_map_method(builder, var_name, local_sym_tab)
347+
348+
else:
349+
logger.warning(f"Unsupported call function type for {var_name}")
350+
351+
352+
def _allocate_for_map_method(builder, var_name, local_sym_tab):
353+
"""Allocate memory for variable assigned from map method (double alloc)."""
354+
355+
# Main variable (pointer to pointer)
356+
ir_type = ir.PointerType(ir.IntType(64))
357+
var = builder.alloca(ir_type, name=var_name)
358+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
359+
360+
# Temporary variable for computed values
361+
tmp_ir_type = ir.IntType(64)
362+
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
363+
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
364+
365+
logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method")
366+
367+
368+
def _allocate_for_constant(builder, var_name, rval, local_sym_tab):
369+
"""Allocate memory for variable assigned from a constant."""
370+
371+
if isinstance(rval.value, bool):
372+
ir_type = ir.IntType(1)
373+
var = builder.alloca(ir_type, name=var_name)
374+
var.align = 1
375+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
376+
logger.info(f"Pre-allocated {var_name} as bool")
377+
378+
elif isinstance(rval.value, int):
379+
ir_type = ir.IntType(64)
380+
var = builder.alloca(ir_type, name=var_name)
381+
var.align = 8
382+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
383+
logger.info(f"Pre-allocated {var_name} as i64")
384+
385+
elif isinstance(rval.value, str):
386+
ir_type = ir.PointerType(ir.IntType(8))
387+
var = builder.alloca(ir_type, name=var_name)
388+
var.align = 8
389+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
390+
logger.info(f"Pre-allocated {var_name} as string")
391+
392+
else:
393+
logger.warning(
394+
f"Unsupported constant type for {var_name}: {type(rval.value).__name__}"
395+
)
396+
397+
398+
def _allocate_for_binop(builder, var_name, local_sym_tab):
399+
"""Allocate memory for variable assigned from a binary operation."""
400+
ir_type = ir.IntType(64) # Assume i64 result
401+
var = builder.alloca(ir_type, name=var_name)
402+
var.align = 8
403+
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
404+
logger.info(f"Pre-allocated {var_name} for binop result")
405+
406+
407+
def _allocate_temp_pool(builder, max_temps, local_sym_tab):
408+
"""Allocate the temporary scratch space pool for helper arguments."""
409+
if max_temps == 0:
410+
return
411+
412+
logger.info(f"Allocating temp pool of {max_temps} variables")
413+
for i in range(max_temps):
414+
temp_name = f"__helper_temp_{i}"
415+
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
416+
temp_var.align = 8
417+
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))
418+
419+
223420
def count_temps_in_call(call_node, local_sym_tab):
224421
"""Count the number of temporary variables needed for a function call."""
225422

@@ -255,7 +452,6 @@ def count_temps_in_call(call_node, local_sym_tab):
255452
def allocate_mem(
256453
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
257454
):
258-
double_alloc = False
259455
max_temps_needed = 0
260456

261457
def update_max_temps_for_stmt(stmt):
@@ -276,139 +472,23 @@ def update_max_temps_for_stmt(stmt):
276472

277473
for stmt in body:
278474
update_max_temps_for_stmt(stmt)
279-
has_metadata = False
475+
476+
# Handle allocations
280477
if isinstance(stmt, ast.If):
281-
if stmt.body:
282-
local_sym_tab = allocate_mem(
283-
module,
284-
builder,
285-
stmt.body,
286-
func,
287-
ret_type,
288-
map_sym_tab,
289-
local_sym_tab,
290-
structs_sym_tab,
291-
)
292-
if stmt.orelse:
293-
local_sym_tab = allocate_mem(
294-
module,
295-
builder,
296-
stmt.orelse,
297-
func,
298-
ret_type,
299-
map_sym_tab,
300-
local_sym_tab,
301-
structs_sym_tab,
302-
)
478+
_handle_if_allocation(
479+
module,
480+
builder,
481+
stmt,
482+
func,
483+
ret_type,
484+
map_sym_tab,
485+
local_sym_tab,
486+
structs_sym_tab,
487+
)
303488
elif isinstance(stmt, ast.Assign):
304-
if len(stmt.targets) != 1:
305-
logger.info("Unsupported multiassignment")
306-
continue
307-
target = stmt.targets[0]
308-
if not isinstance(target, ast.Name) and not isinstance(
309-
target, ast.Attribute
310-
):
311-
logger.info("Unsupported assignment target")
312-
continue
313-
if isinstance(target, ast.Attribute):
314-
logger.info(
315-
f"Struct field {target.attr} assignment, will be handled later"
316-
)
317-
continue
318-
var_name = target.id
319-
rval = stmt.value
320-
if var_name in local_sym_tab:
321-
logger.info(f"Variable {var_name} already allocated")
322-
continue
323-
if isinstance(rval, ast.Call):
324-
if isinstance(rval.func, ast.Name):
325-
call_type = rval.func.id
326-
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
327-
ir_type = ctypes_to_ir(call_type)
328-
var = builder.alloca(ir_type, name=var_name)
329-
var.align = ir_type.width // 8
330-
logger.info(
331-
f"Pre-allocated variable {var_name} of type {call_type}"
332-
)
333-
elif HelperHandlerRegistry.has_handler(call_type):
334-
# Assume return type is int64 for now
335-
ir_type = ir.IntType(64)
336-
var = builder.alloca(ir_type, name=var_name)
337-
var.align = ir_type.width // 8
338-
logger.info(f"Pre-allocated variable {var_name} for helper")
339-
elif call_type == "deref" and len(rval.args) == 1:
340-
# Assume return type is int64 for now
341-
ir_type = ir.IntType(64)
342-
var = builder.alloca(ir_type, name=var_name)
343-
var.align = ir_type.width // 8
344-
logger.info(f"Pre-allocated variable {var_name} for deref")
345-
elif call_type in structs_sym_tab:
346-
struct_info = structs_sym_tab[call_type]
347-
ir_type = struct_info.ir_type
348-
var = builder.alloca(ir_type, name=var_name)
349-
has_metadata = True
350-
logger.info(
351-
f"Pre-allocated variable {var_name} for struct {call_type}"
352-
)
353-
elif isinstance(rval.func, ast.Attribute):
354-
# Map method call
355-
ir_type = ir.PointerType(ir.IntType(64))
356-
var = builder.alloca(ir_type, name=var_name)
357-
358-
# declare an intermediate ptr type for map lookup
359-
tmp_ir_type = ir.IntType(64)
360-
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
361-
double_alloc = True
362-
# var.align = ir_type.width // 8
363-
logger.info(
364-
f"Pre-allocated variable {var_name} and {var_name}_tmp for map"
365-
)
366-
else:
367-
logger.info("Unsupported assignment call function type")
368-
continue
369-
elif isinstance(rval, ast.Constant):
370-
if isinstance(rval.value, bool):
371-
ir_type = ir.IntType(1)
372-
var = builder.alloca(ir_type, name=var_name)
373-
var.align = 1
374-
logger.info(f"Pre-allocated variable {var_name} of type c_bool")
375-
elif isinstance(rval.value, int):
376-
# Assume c_int64 for now
377-
ir_type = ir.IntType(64)
378-
var = builder.alloca(ir_type, name=var_name)
379-
var.align = ir_type.width // 8
380-
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
381-
elif isinstance(rval.value, str):
382-
ir_type = ir.PointerType(ir.IntType(8))
383-
var = builder.alloca(ir_type, name=var_name)
384-
var.align = 8
385-
logger.info(f"Pre-allocated variable {var_name} of type string")
386-
else:
387-
logger.info("Unsupported constant type")
388-
continue
389-
elif isinstance(rval, ast.BinOp):
390-
# Assume c_int64 for now
391-
ir_type = ir.IntType(64)
392-
var = builder.alloca(ir_type, name=var_name)
393-
var.align = ir_type.width // 8
394-
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
395-
else:
396-
logger.info("Unsupported assignment value type")
397-
continue
398-
399-
if has_metadata:
400-
local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type)
401-
else:
402-
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
403-
404-
if double_alloc:
405-
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
406-
407-
logger.info(f"Temporary scratch space needed for calls: {max_temps_needed}")
408-
for i in range(max_temps_needed):
409-
temp_var = builder.alloca(ir.IntType(64), name=f"__helper_temp_{i}")
410-
temp_var.align = 8
411-
local_sym_tab[f"__helper_temp_{i}"] = LocalSymbol(temp_var, ir.IntType(64))
489+
_handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
490+
491+
_allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
412492

413493
return local_sym_tab
414494

0 commit comments

Comments
 (0)