Skip to content

Commit 1a0e21e

Browse files
support vmlinux enum in map arguments
1 parent 190baf2 commit 1a0e21e

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

pythonbpf/expr/expr_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def _handle_unary_op(
349349
neg_one = ir.Constant(ir.IntType(64), -1)
350350
result = builder.mul(operand, neg_one)
351351
return result, ir.IntType(64)
352+
return None
352353

353354

354355
# ============================================================================

pythonbpf/maps/maps_pass.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .maps_utils import MapProcessorRegistry
77
from .map_types import BPFMapType
88
from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info
9+
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
10+
911

1012
logger: Logger = logging.getLogger(__name__)
1113

@@ -51,7 +53,7 @@ def _parse_map_params(rval, expected_args=None):
5153
"""Parse map parameters from call arguments and keywords."""
5254

5355
params = {}
54-
56+
handler = VmlinuxHandlerRegistry.get_handler()
5557
# Parse positional arguments
5658
if expected_args:
5759
for i, arg_name in enumerate(expected_args):
@@ -65,7 +67,12 @@ def _parse_map_params(rval, expected_args=None):
6567
# Parse keyword arguments (override positional)
6668
for keyword in rval.keywords:
6769
if isinstance(keyword.value, ast.Name):
68-
params[keyword.arg] = keyword.value.id
70+
name = keyword.value.id
71+
if handler and handler.is_vmlinux_enum(name):
72+
result = handler.get_vmlinux_enum_value(name)
73+
params[keyword.arg] = result if result is not None else name
74+
else:
75+
params[keyword.arg] = name
6976
elif isinstance(keyword.value, ast.Constant):
7077
params[keyword.arg] = keyword.value.value
7178

pythonbpf/vmlinux_parser/vmlinux_exports_handler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def handle_vmlinux_enum(self, name):
5454
return ir.Constant(ir.IntType(64), value), ir.IntType(64)
5555
return None
5656

57+
def get_vmlinux_enum_value(self, name):
58+
"""Handle vmlinux enum constants by returning LLVM IR constants"""
59+
if self.is_vmlinux_enum(name):
60+
value = self.vmlinux_symtab[name]["value"]
61+
logger.info(f"The value of vmlinux enum {name} = {value}")
62+
return value
63+
return None
64+
5765
def handle_vmlinux_struct(self, struct_name, module, builder):
5866
"""Handle vmlinux struct initializations"""
5967
if self.is_vmlinux_struct(struct_name):

tests/passing_tests/vmlinux/simple_struct_test.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
11
import logging
22

3-
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
3+
from pythonbpf import bpf, section, bpfglobal, compile_to_ir, map
44
from pythonbpf import compile # noqa: F401
55
from vmlinux import TASK_COMM_LEN # noqa: F401
66
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
7+
from ctypes import c_uint64, c_int32, c_int64
8+
from pythonbpf.maps import HashMap
79

810
# from vmlinux import struct_uinput_device
911
# from vmlinux import struct_blk_integrity_iter
10-
from ctypes import c_int64
12+
13+
14+
@bpf
15+
@map
16+
def mymap() -> HashMap:
17+
return HashMap(key=c_int32, value=c_uint64, max_entries=TASK_COMM_LEN)
18+
19+
20+
@bpf
21+
@map
22+
def mymap2() -> HashMap:
23+
return HashMap(key=c_int32, value=c_uint64, max_entries=18)
1124

1225

1326
# Instructions to how to run this program
@@ -21,7 +34,7 @@
2134
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
2235
a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
2336
print(f"Hello, World{TASK_COMM_LEN} and {a}")
24-
return c_int64(TASK_COMM_LEN)
37+
return c_int64(TASK_COMM_LEN + 2)
2538

2639

2740
@bpf
@@ -31,4 +44,4 @@ def LICENSE() -> str:
3144

3245

3346
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG)
34-
compile()
47+
# compile()

0 commit comments

Comments
 (0)