Skip to content

Commit 123a92a

Browse files
committed
Change allocation pass to generate typed temp variables
1 parent 752f564 commit 123a92a

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

pythonbpf/allocation_pass.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,33 @@ def _allocate_for_binop(builder, var_name, local_sym_tab):
199199
logger.info(f"Pre-allocated {var_name} for binop result")
200200

201201

202+
def _get_type_name(ir_type):
203+
"""Get a string representation of an IR type."""
204+
if isinstance(ir_type, ir.IntType):
205+
return f"i{ir_type.width}"
206+
elif isinstance(ir_type, ir.PointerType):
207+
return "ptr"
208+
elif isinstance(ir_type, ir.ArrayType):
209+
return f"[{ir_type.count}x{_get_type_name(ir_type.element)}]"
210+
else:
211+
return str(ir_type).replace(" ", "")
212+
213+
202214
def allocate_temp_pool(builder, max_temps, local_sym_tab):
203215
"""Allocate the temporary scratch space pool for helper arguments."""
204-
if max_temps == 0:
216+
if not max_temps:
217+
logger.info("No temp pool allocation needed")
205218
return
206219

207-
logger.info(f"Allocating temp pool of {max_temps} variables")
208-
for i in range(max_temps):
209-
temp_name = f"__helper_temp_{i}"
210-
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
211-
temp_var.align = 8
212-
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))
220+
for tmp_type, cnt in max_temps.items():
221+
type_name = _get_type_name(tmp_type)
222+
logger.info(f"Allocating temp pool of {cnt} variables of type {type_name}")
223+
for i in range(cnt):
224+
temp_name = f"__helper_temp_{type_name}_{i}"
225+
temp_var = builder.alloca(tmp_type, name=temp_name)
226+
temp_var.align = _get_alignment(tmp_type)
227+
local_sym_tab[temp_name] = LocalSymbol(temp_var, tmp_type)
228+
logger.debug(f"Allocated temp variable: {temp_name}")
213229

214230

215231
def _allocate_for_name(builder, var_name, rval, local_sym_tab):

pythonbpf/functions/functions_pass.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,15 @@ def handle_if_allocation(
9898
def allocate_mem(
9999
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
100100
):
101-
max_temps_needed = 0
101+
max_temps_needed = {}
102+
103+
def merge_type_counts(count_dict):
104+
nonlocal max_temps_needed
105+
for typ, cnt in count_dict.items():
106+
max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt)
102107

103108
def update_max_temps_for_stmt(stmt):
104109
nonlocal max_temps_needed
105-
temps_needed = 0
106110

107111
if isinstance(stmt, ast.If):
108112
for s in stmt.body:
@@ -111,10 +115,13 @@ def update_max_temps_for_stmt(stmt):
111115
update_max_temps_for_stmt(s)
112116
return
113117

118+
stmt_temps = {}
114119
for node in ast.walk(stmt):
115120
if isinstance(node, ast.Call):
116-
temps_needed += count_temps_in_call(node, local_sym_tab)
117-
max_temps_needed = max(max_temps_needed, temps_needed)
121+
call_temps = count_temps_in_call(node, local_sym_tab)
122+
for typ, cnt in call_temps.items():
123+
stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt
124+
merge_type_counts(stmt_temps)
118125

119126
for stmt in body:
120127
update_max_temps_for_stmt(stmt)

0 commit comments

Comments
 (0)