Skip to content

Commit 5a8b64f

Browse files
authored
Merge pull request #64 from pythonbpf/all_helpers
Add support for all eBPF helpers
2 parents faad355 + cf99b3b commit 5a8b64f

File tree

12 files changed

+678
-95
lines changed

12 files changed

+678
-95
lines changed

pythonbpf/allocation_pass.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,33 @@ def _allocate_for_binop(builder, var_name, local_sym_tab):
177177
logger.info(f"Pre-allocated {var_name} for binop result")
178178

179179

180+
def _get_type_name(ir_type):
181+
"""Get a string representation of an IR type."""
182+
if isinstance(ir_type, ir.IntType):
183+
return f"i{ir_type.width}"
184+
elif isinstance(ir_type, ir.PointerType):
185+
return "ptr"
186+
elif isinstance(ir_type, ir.ArrayType):
187+
return f"[{ir_type.count}x{_get_type_name(ir_type.element)}]"
188+
else:
189+
return str(ir_type).replace(" ", "")
190+
191+
180192
def allocate_temp_pool(builder, max_temps, local_sym_tab):
181193
"""Allocate the temporary scratch space pool for helper arguments."""
182-
if max_temps == 0:
194+
if not max_temps:
195+
logger.info("No temp pool allocation needed")
183196
return
184197

185-
logger.info(f"Allocating temp pool of {max_temps} variables")
186-
for i in range(max_temps):
187-
temp_name = f"__helper_temp_{i}"
188-
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
189-
temp_var.align = 8
190-
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))
198+
for tmp_type, cnt in max_temps.items():
199+
type_name = _get_type_name(tmp_type)
200+
logger.info(f"Allocating temp pool of {cnt} variables of type {type_name}")
201+
for i in range(cnt):
202+
temp_name = f"__helper_temp_{type_name}_{i}"
203+
temp_var = builder.alloca(tmp_type, name=temp_name)
204+
temp_var.align = _get_alignment(tmp_type)
205+
local_sym_tab[temp_name] = LocalSymbol(temp_var, tmp_type)
206+
logger.debug(f"Allocated temp variable: {temp_name}")
191207

192208

193209
def _allocate_for_name(builder, var_name, rval, local_sym_tab):

pythonbpf/functions/functions_pass.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
def count_temps_in_call(call_node, local_sym_tab):
4040
"""Count the number of temporary variables needed for a function call."""
4141

42-
count = 0
42+
count = {}
4343
is_helper = False
4444

4545
# NOTE: We exclude print calls for now
@@ -49,21 +49,28 @@ def count_temps_in_call(call_node, local_sym_tab):
4949
and call_node.func.id != "print"
5050
):
5151
is_helper = True
52+
func_name = call_node.func.id
5253
elif isinstance(call_node.func, ast.Attribute):
5354
if HelperHandlerRegistry.has_handler(call_node.func.attr):
5455
is_helper = True
56+
func_name = call_node.func.attr
5557

5658
if not is_helper:
57-
return 0
59+
return {} # No temps needed
5860

59-
for arg in call_node.args:
61+
for arg_idx in range(len(call_node.args)):
6062
# NOTE: Count all non-name arguments
6163
# For struct fields, if it is being passed as an argument,
6264
# The struct object should already exist in the local_sym_tab
63-
if not isinstance(arg, ast.Name) and not (
65+
arg = call_node.args[arg_idx]
66+
if isinstance(arg, ast.Name) or (
6467
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
6568
):
66-
count += 1
69+
continue
70+
param_type = HelperHandlerRegistry.get_param_type(func_name, arg_idx)
71+
if isinstance(param_type, ir.PointerType):
72+
pointee_type = param_type.pointee
73+
count[pointee_type] = count.get(pointee_type, 0) + 1
6774

6875
return count
6976

@@ -99,11 +106,15 @@ def handle_if_allocation(
99106
def allocate_mem(
100107
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
101108
):
102-
max_temps_needed = 0
109+
max_temps_needed = {}
110+
111+
def merge_type_counts(count_dict):
112+
nonlocal max_temps_needed
113+
for typ, cnt in count_dict.items():
114+
max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt)
103115

104116
def update_max_temps_for_stmt(stmt):
105117
nonlocal max_temps_needed
106-
temps_needed = 0
107118

108119
if isinstance(stmt, ast.If):
109120
for s in stmt.body:
@@ -112,10 +123,13 @@ def update_max_temps_for_stmt(stmt):
112123
update_max_temps_for_stmt(s)
113124
return
114125

126+
stmt_temps = {}
115127
for node in ast.walk(stmt):
116128
if isinstance(node, ast.Call):
117-
temps_needed += count_temps_in_call(node, local_sym_tab)
118-
max_temps_needed = max(max_temps_needed, temps_needed)
129+
call_temps = count_temps_in_call(node, local_sym_tab)
130+
for typ, cnt in call_temps.items():
131+
stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt
132+
merge_type_counts(stmt_temps)
119133

120134
for stmt in body:
121135
update_max_temps_for_stmt(stmt)

pythonbpf/helper/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
from .helper_registry import HelperHandlerRegistry
22
from .helper_utils import reset_scratch_pool
33
from .bpf_helper_handler import handle_helper_call, emit_probe_read_kernel_str_call
4-
from .helpers import ktime, pid, deref, comm, probe_read_str, XDP_DROP, XDP_PASS
4+
from .helpers import (
5+
ktime,
6+
pid,
7+
deref,
8+
comm,
9+
probe_read_str,
10+
random,
11+
probe_read,
12+
smp_processor_id,
13+
uid,
14+
skb_store_bytes,
15+
XDP_DROP,
16+
XDP_PASS,
17+
)
518

619

720
# Register the helper handler with expr module
@@ -65,6 +78,11 @@ def helper_call_handler(
6578
"deref",
6679
"comm",
6780
"probe_read_str",
81+
"random",
82+
"probe_read",
83+
"smp_processor_id",
84+
"uid",
85+
"skb_store_bytes",
6886
"XDP_DROP",
6987
"XDP_PASS",
7088
]

0 commit comments

Comments
 (0)