Skip to content

Commit 963e2a8

Browse files
committed
Change ScratchPoolManager to use typed scratch space
1 parent 123a92a commit 963e2a8

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

pythonbpf/functions/functions_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def count_temps_in_call(call_node, local_sym_tab):
5050
func_name = call_node.func.attr
5151

5252
if not is_helper:
53-
return 0
53+
return {} # No temps needed
5454

5555
for arg_idx in range(len(call_node.args)):
5656
# NOTE: Count all non-name arguments

pythonbpf/helper/helper_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_signature(cls, helper_name):
5050
def get_param_type(cls, helper_name, index):
5151
"""Get the type of a parameter of a helper function by the index"""
5252
signature = cls.get_signature(helper_name)
53-
if signature and 0 <= index < len(signature.arg_types):
53+
if signature and signature.arg_types and 0 <= index < len(signature.arg_types):
5454
return signature.arg_types[index]
5555
return None
5656

pythonbpf/helper/helper_utils.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,43 @@ class ScratchPoolManager:
1414
"""Manage the temporary helper variables in local_sym_tab"""
1515

1616
def __init__(self):
17-
self._counter = 0
17+
self._counters = {}
1818

1919
@property
2020
def counter(self):
21-
return self._counter
21+
return sum(self._counter.values())
2222

2323
def reset(self):
24-
self._counter = 0
24+
self._counters.clear()
2525
logger.debug("Scratch pool counter reset to 0")
2626

27-
def get_next_temp(self, local_sym_tab):
28-
temp_name = f"__helper_temp_{self._counter}"
29-
self._counter += 1
27+
def _get_type_name(self, ir_type):
28+
if isinstance(ir_type, ir.PointerType):
29+
return "ptr"
30+
elif isinstance(ir_type, ir.IntType):
31+
return f"i{ir_type.width}"
32+
elif isinstance(ir_type, ir.ArrayType):
33+
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
34+
else:
35+
return str(ir_type).replace(" ", "")
36+
37+
def get_next_temp(self, local_sym_tab, expected_type=None):
38+
# Default to i64 if no expected type provided
39+
type_name = self._get_type_name(expected_type) if expected_type else "i64"
40+
if type_name not in self._counters:
41+
self._counters[type_name] = 0
42+
43+
counter = self._counters[type_name]
44+
temp_name = f"__helper_temp_{type_name}_{counter}"
45+
self._counters[type_name] += 1
3046

3147
if temp_name not in local_sym_tab:
3248
raise ValueError(
3349
f"Scratch pool exhausted or inadequate: {temp_name}. "
34-
f"Current counter: {self._counter}"
50+
f"Type: {type_name} Counter: {counter}"
3551
)
3652

53+
logger.debug(f"Using {temp_name} for type {type_name}")
3754
return local_sym_tab[temp_name].var, temp_name
3855

3956

0 commit comments

Comments
 (0)