Skip to content

Commit 2cf68f6

Browse files
committed
Allow map-based helpers to be used as helper args / within binops which are helper args
1 parent d66e6a6 commit 2cf68f6

File tree

4 files changed

+56
-18
lines changed

4 files changed

+56
-18
lines changed

pythonbpf/binary_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def get_operand_value(
3939
if res is None:
4040
raise ValueError(f"Failed to evaluate call expression: {operand}")
4141
val, _ = res
42+
logger.info(f"Evaluated expr to {val} of type {val.type}")
43+
base_type, depth = get_base_type_and_depth(val.type)
44+
if depth > 0:
45+
val = deref_to_depth(func, builder, val, depth)
4246
return val
4347
raise TypeError(f"Unsupported operand type: {type(operand)}")
4448

pythonbpf/functions/functions_pass.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,18 @@ def process_stmt(
388388
return did_return
389389

390390

391-
def count_temps_in_call(call_node):
391+
def count_temps_in_call(call_node, local_sym_tab):
392392
"""Count the number of temporary variables needed for a function call."""
393393

394394
count = 0
395395
is_helper = False
396396

397+
# NOTE: We exclude print calls for now
397398
if isinstance(call_node.func, ast.Name):
398-
if HelperHandlerRegistry.has_handler(call_node.func.id):
399+
if (
400+
HelperHandlerRegistry.has_handler(call_node.func.id)
401+
and call_node.func.id != "print"
402+
):
399403
is_helper = True
400404
elif isinstance(call_node.func, ast.Attribute):
401405
if HelperHandlerRegistry.has_handler(call_node.func.attr):
@@ -405,10 +409,11 @@ def count_temps_in_call(call_node):
405409
return 0
406410

407411
for arg in call_node.args:
408-
if (
409-
isinstance(arg, ast.BinOp)
410-
or isinstance(arg, ast.Constant)
411-
or isinstance(arg, ast.UnaryOp)
412+
# NOTE: Count all non-name arguments
413+
# For struct fields, if it is being passed as an argument,
414+
# The struct object should already exist in the local_sym_tab
415+
if not isinstance(arg, ast.Name) and not (
416+
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
412417
):
413418
count += 1
414419

@@ -423,11 +428,19 @@ def allocate_mem(
423428

424429
def update_max_temps_for_stmt(stmt):
425430
nonlocal max_temps_needed
431+
temps_needed = 0
432+
433+
if isinstance(stmt, ast.If):
434+
for s in stmt.body:
435+
update_max_temps_for_stmt(s)
436+
for s in stmt.orelse:
437+
update_max_temps_for_stmt(s)
438+
return
426439

427440
for node in ast.walk(stmt):
428441
if isinstance(node, ast.Call):
429-
temps_needed = count_temps_in_call(node)
430-
max_temps_needed = max(max_temps_needed, temps_needed)
442+
temps_needed += count_temps_in_call(node, local_sym_tab)
443+
max_temps_needed = max(max_temps_needed, temps_needed)
431444

432445
for stmt in body:
433446
update_max_temps_for_stmt(stmt)
@@ -460,9 +473,16 @@ def update_max_temps_for_stmt(stmt):
460473
logger.info("Unsupported multiassignment")
461474
continue
462475
target = stmt.targets[0]
463-
if not isinstance(target, ast.Name):
476+
if not isinstance(target, ast.Name) and not isinstance(
477+
target, ast.Attribute
478+
):
464479
logger.info("Unsupported assignment target")
465480
continue
481+
if isinstance(target, ast.Attribute):
482+
logger.info(
483+
f"Struct field {target.attr} assignment, will be handled later"
484+
)
485+
continue
466486
var_name = target.id
467487
rval = stmt.value
468488
if var_name in local_sym_tab:

pythonbpf/helper/bpf_helper_handler.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def bpf_ktime_get_ns_emitter(
3434
func,
3535
local_sym_tab=None,
3636
struct_sym_tab=None,
37+
map_sym_tab=None,
3738
):
3839
"""
3940
Emit LLVM IR for bpf_ktime_get_ns helper function call.
@@ -56,6 +57,7 @@ def bpf_map_lookup_elem_emitter(
5657
func,
5758
local_sym_tab=None,
5859
struct_sym_tab=None,
60+
map_sym_tab=None,
5961
):
6062
"""
6163
Emit LLVM IR for bpf_map_lookup_elem helper function call.
@@ -65,12 +67,16 @@ def bpf_map_lookup_elem_emitter(
6567
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
6668
)
6769
key_ptr = get_or_create_ptr_from_arg(
68-
func, module, call.args[0], builder, local_sym_tab, struct_sym_tab
70+
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
6971
)
7072
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
7173

74+
# TODO: I have changed the return typr to i64*, as we are
75+
# allocating space for that type in allocate_mem. This is
76+
# temporary, and we will honour other widths later. But this
77+
# allows us to have cool binary ops on the returned value.
7278
fn_type = ir.FunctionType(
73-
ir.PointerType(), # Return type: void*
79+
ir.PointerType(ir.IntType(64)), # Return type: void*
7480
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
7581
var_arg=False,
7682
)
@@ -93,6 +99,7 @@ def bpf_printk_emitter(
9399
func,
94100
local_sym_tab=None,
95101
struct_sym_tab=None,
102+
map_sym_tab=None,
96103
):
97104
"""Emit LLVM IR for bpf_printk helper function call."""
98105
if not hasattr(func, "_fmt_counter"):
@@ -140,6 +147,7 @@ def bpf_map_update_elem_emitter(
140147
func,
141148
local_sym_tab=None,
142149
struct_sym_tab=None,
150+
map_sym_tab=None,
143151
):
144152
"""
145153
Emit LLVM IR for bpf_map_update_elem helper function call.
@@ -155,10 +163,10 @@ def bpf_map_update_elem_emitter(
155163
flags_arg = call.args[2] if len(call.args) > 2 else None
156164

157165
key_ptr = get_or_create_ptr_from_arg(
158-
func, module, key_arg, builder, local_sym_tab, struct_sym_tab
166+
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
159167
)
160168
value_ptr = get_or_create_ptr_from_arg(
161-
func, module, value_arg, builder, local_sym_tab, struct_sym_tab
169+
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
162170
)
163171
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
164172

@@ -194,6 +202,7 @@ def bpf_map_delete_elem_emitter(
194202
func,
195203
local_sym_tab=None,
196204
struct_sym_tab=None,
205+
map_sym_tab=None,
197206
):
198207
"""
199208
Emit LLVM IR for bpf_map_delete_elem helper function call.
@@ -204,7 +213,7 @@ def bpf_map_delete_elem_emitter(
204213
f"Map delete expects exactly one argument (key), got {len(call.args)}"
205214
)
206215
key_ptr = get_or_create_ptr_from_arg(
207-
func, module, call.args[0], builder, local_sym_tab, struct_sym_tab
216+
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
208217
)
209218
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
210219

@@ -233,6 +242,7 @@ def bpf_get_current_pid_tgid_emitter(
233242
func,
234243
local_sym_tab=None,
235244
struct_sym_tab=None,
245+
map_sym_tab=None,
236246
):
237247
"""
238248
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
@@ -259,6 +269,7 @@ def bpf_perf_event_output_handler(
259269
func,
260270
local_sym_tab=None,
261271
struct_sym_tab=None,
272+
map_sym_tab=None,
262273
):
263274
if len(call.args) != 1:
264275
raise ValueError(
@@ -323,6 +334,7 @@ def invoke_helper(method_name, map_ptr=None):
323334
func,
324335
local_sym_tab,
325336
struct_sym_tab,
337+
map_sym_tab,
326338
)
327339

328340
# Handle direct function calls (e.g., print(), ktime())

pythonbpf/helper/helper_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,14 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
8181

8282
# Default to 64-bit integer
8383
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
84-
logger.debug(f"Using temp variable '{temp_name}' for int constant {value}")
84+
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
8585
const_val = ir.Constant(ir.IntType(int_width), value)
8686
builder.store(const_val, ptr)
8787
return ptr
8888

8989

9090
def get_or_create_ptr_from_arg(
91-
func, module, arg, builder, local_sym_tab, struct_sym_tab=None
91+
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None
9292
):
9393
"""Extract or create pointer from the call arguments."""
9494

@@ -104,15 +104,17 @@ def get_or_create_ptr_from_arg(
104104
builder,
105105
arg,
106106
local_sym_tab,
107-
None,
107+
map_sym_tab,
108108
struct_sym_tab,
109109
)
110110
if val is None:
111111
raise ValueError("Failed to evaluate expression for helper arg.")
112112

113113
# NOTE: We assume the result is an int64 for now
114+
# if isinstance(arg, ast.Attribute):
115+
# return val
114116
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
115-
logger.debug(f"Using temp variable '{temp_name}' for expression result")
117+
logger.info(f"Using temp variable '{temp_name}' for expression result")
116118
builder.store(val, ptr)
117119

118120
return ptr

0 commit comments

Comments
 (0)