Skip to content

Commit 80396c7

Browse files
recursive parsing fix without ctypes in recursed type
1 parent 8774277 commit 80396c7

File tree

3 files changed

+74
-129
lines changed

3 files changed

+74
-129
lines changed

pythonbpf/vmlinux_parser/class_handler.py

Lines changed: 68 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .dependency_handler import DependencyHandler
55
from .dependency_node import DependencyNode
66
import ctypes
7+
from typing import Optional, Any
78

89
logger = logging.getLogger(__name__)
910

@@ -13,6 +14,7 @@ def get_module_symbols(module_name: str):
1314
imported_module = importlib.import_module(module_name)
1415
return [name for name in dir(imported_module)], imported_module
1516

17+
1618
def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
1719
symbols_in_module, imported_module = get_module_symbols("vmlinux")
1820
if node.name in symbols_in_module:
@@ -21,37 +23,30 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
2123
else:
2224
raise ImportError(f"{node.name} not in vmlinux")
2325

24-
# Recursive function that gets all the dependent classes and adds them to handler
25-
def process_vmlinux_post_ast(node, llvm_module, handler: DependencyHandler, processing_stack=None):
26-
"""
27-
Recursively process vmlinux classes and their dependencies.
28-
29-
Args:
30-
node: The class/type to process
31-
llvm_module: The LLVM module context
32-
handler: DependencyHandler to track all nodes
33-
processing_stack: Set of currently processing nodes to detect cycles
34-
"""
26+
27+
def process_vmlinux_post_ast(
28+
elem_type_class, llvm_handler, handler: DependencyHandler, processing_stack=None
29+
):
3530
# Initialize processing stack on first call
3631
if processing_stack is None:
3732
processing_stack = set()
38-
3933
symbols_in_module, imported_module = get_module_symbols("vmlinux")
4034

41-
# Handle both node objects and type objects
42-
if hasattr(node, "name"):
43-
current_symbol_name = node.name
44-
elif hasattr(node, "__name__"):
45-
current_symbol_name = node.__name__
46-
else:
47-
current_symbol_name = str(node)
35+
current_symbol_name = elem_type_class.__name__
36+
field_table = {}
37+
is_complex_type = False
38+
containing_type: Optional[Any] = None
39+
ctype_complex_type: Optional[Any] = None
40+
type_length: Optional[int] = None
41+
module_name = getattr(elem_type_class, "__module__", None)
4842

49-
if current_symbol_name not in symbols_in_module:
50-
raise ImportError(f"{current_symbol_name} not present in module vmlinux")
43+
if hasattr(elem_type_class, "_length_") and is_complex_type:
44+
type_length = elem_type_class._length_
5145

52-
# Check if we're already processing this node (circular dependency)
5346
if current_symbol_name in processing_stack:
54-
logger.debug(f"Circular dependency detected for {current_symbol_name}, skipping")
47+
logger.debug(
48+
f"Circular dependency detected for {current_symbol_name}, skipping"
49+
)
5550
return True
5651

5752
# Check if already processed
@@ -62,116 +57,64 @@ def process_vmlinux_post_ast(node, llvm_module, handler: DependencyHandler, proc
6257
logger.info(f"Node {current_symbol_name} already processed and ready")
6358
return True
6459

65-
logger.info(f"Resolving vmlinux class {current_symbol_name}")
66-
logger.debug(
67-
f"Current handler state: {handler.is_ready} readiness and {handler.get_all_nodes()} all nodes"
68-
)
69-
70-
# Add to processing stack to detect cycles
7160
processing_stack.add(current_symbol_name)
7261

73-
try:
74-
field_table = {} # should contain the field and it's type.
75-
76-
# Get the class object from the module
77-
class_obj = getattr(imported_module, current_symbol_name)
78-
79-
# Inspect the class fields
80-
if hasattr(class_obj, "_fields_"):
81-
for field_name, field_type in class_obj._fields_:
82-
field_table[field_name] = field_type
83-
elif hasattr(class_obj, "__annotations__"):
84-
for field_name, field_type in class_obj.__annotations__.items():
85-
field_table[field_name] = field_type
86-
else:
87-
raise TypeError("Could not get required class and definition")
88-
89-
logger.debug(f"Extracted fields for {current_symbol_name}: {field_table}")
62+
if module_name == "vmlinux":
63+
if hasattr(elem_type_class, "_type_"):
64+
is_complex_type = True
65+
containing_type = elem_type_class._type_
66+
if containing_type.__module__ == "vmlinux":
67+
print("Very weird type ig for containing type", containing_type)
68+
elif containing_type.__module__ == ctypes.__name__:
69+
if isinstance(elem_type_class, type):
70+
if issubclass(elem_type_class, ctypes.Array):
71+
ctype_complex_type = ctypes.Array
72+
elif issubclass(elem_type_class, ctypes._Pointer):
73+
ctype_complex_type = ctypes._Pointer
74+
else:
75+
raise TypeError("Unsupported ctypes subclass")
76+
# handle ctype complex type
9077

91-
# Create or get the node
92-
if handler.has_node(current_symbol_name):
93-
new_dep_node = handler.get_node(current_symbol_name)
78+
else:
79+
raise ImportError(f"Unsupported module of {containing_type}")
9480
else:
9581
new_dep_node = DependencyNode(name=current_symbol_name)
9682
handler.add_node(new_dep_node)
97-
98-
# Process each field
99-
for elem_name, elem_type in field_table.items():
100-
module_name = getattr(elem_type, "__module__", None)
101-
102-
if module_name == ctypes.__name__:
103-
# Simple ctypes - mark as ready immediately
104-
new_dep_node.add_field(elem_name, elem_type, ready=True)
105-
106-
elif module_name == "vmlinux":
107-
# Complex vmlinux type - needs recursive processing
108-
new_dep_node.add_field(elem_name, elem_type, ready=False)
109-
logger.debug(f"Processing vmlinux field: {elem_name}, type: {elem_type}")
110-
111-
identify_ctypes_type(elem_name, elem_type, new_dep_node)
112-
113-
# Determine the actual symbol to process
114-
symbol_name = (
115-
elem_type.__name__
116-
if hasattr(elem_type, "__name__")
117-
else str(elem_type)
118-
)
119-
vmlinux_symbol = None
120-
121-
# Handle pointers/arrays to other types
122-
if hasattr(elem_type, "_type_"):
123-
containing_module_name = getattr(
124-
(elem_type._type_), "__module__", None
83+
class_obj = getattr(imported_module, current_symbol_name)
84+
# Inspect the class fields
85+
if hasattr(class_obj, "_fields_"):
86+
for field_name, field_type in class_obj._fields_:
87+
field_table[field_name] = field_type
88+
elif hasattr(class_obj, "__annotations__"):
89+
for field_name, field_type in class_obj.__annotations__.items():
90+
field_table[field_name] = field_type
91+
else:
92+
raise TypeError("Could not get required class and definition")
93+
94+
logger.info(f"Extracted fields for {current_symbol_name}: {field_table}")
95+
96+
for elem_name, elem_type in field_table.items():
97+
local_module_name = getattr(elem_type, "__module__", None)
98+
if local_module_name == ctypes.__name__:
99+
new_dep_node.add_field(elem_name, elem_type, ready=True)
100+
logger.info(f"Field {elem_name} is direct ctypes type: {elem_type}")
101+
elif local_module_name == "vmlinux":
102+
new_dep_node.add_field(elem_name, elem_type, ready=False)
103+
logger.debug(
104+
f"Processing vmlinux field: {elem_name}, type: {elem_type}"
125105
)
126-
if containing_module_name == ctypes.__name__:
127-
# Pointer/Array to ctypes - mark as ready
106+
if process_vmlinux_post_ast(
107+
elem_type, llvm_handler, handler, processing_stack
108+
):
128109
new_dep_node.set_field_ready(elem_name, True)
129-
continue
130-
elif containing_module_name == "vmlinux":
131-
# Pointer/Array to vmlinux type
132-
symbol_name = (
133-
(elem_type._type_).__name__
134-
if hasattr((elem_type._type_), "__name__")
135-
else str(elem_type._type_)
136-
)
137-
138-
# Self-referential check
139-
if symbol_name == current_symbol_name:
140-
logger.debug(f"Self-referential field {elem_name} in {current_symbol_name}")
141-
# For pointers to self, we can mark as ready since the type is being defined
142-
new_dep_node.set_field_ready(elem_name, True)
143-
continue
144-
145-
vmlinux_symbol = getattr(imported_module, symbol_name)
146110
else:
147-
# Direct vmlinux type (not pointer/array)
148-
vmlinux_symbol = getattr(imported_module, symbol_name)
149-
150-
# Recursively process the dependency
151-
if vmlinux_symbol is not None:
152-
if process_vmlinux_post_ast(vmlinux_symbol, llvm_module, handler, processing_stack):
153-
new_dep_node.set_field_ready(elem_name, True)
154-
else:
155-
raise ValueError(
156-
f"{elem_name} with type {elem_type} not supported in recursive resolver"
157-
)
158-
159-
logger.info(f"Successfully processed node: {current_symbol_name}")
160-
return True
161-
162-
finally:
163-
# Remove from processing stack when done
164-
processing_stack.discard(current_symbol_name)
165-
111+
raise ValueError(
112+
f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver"
113+
)
114+
print("")
166115

167-
def identify_ctypes_type(elem_name, elem_type, new_dep_node: DependencyNode):
168-
if isinstance(elem_type, type):
169-
if issubclass(elem_type, ctypes.Array):
170-
new_dep_node.set_field_type(elem_name, ctypes.Array)
171-
new_dep_node.set_field_containing_type(elem_name, elem_type._type_)
172-
new_dep_node.set_field_type_size(elem_name, elem_type._length_)
173-
elif issubclass(elem_type, ctypes._Pointer):
174-
new_dep_node.set_field_type(elem_name, ctypes._Pointer)
175-
new_dep_node.set_field_containing_type(elem_name, elem_type._type_)
176116
else:
177-
raise TypeError("Instance sent instead of Class")
117+
raise ImportError("UNSUPPORTED Module")
118+
119+
print(current_symbol_name, "DONE")
120+
print(f"handler readiness {handler.is_ready}")

pythonbpf/vmlinux_parser/dependency_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass, field
22
from typing import Dict, Any, Optional
33

4-
#TODO: FIX THE FUCKING TYPE NAME CONVENTION.
4+
5+
# TODO: FIX THE FUCKING TYPE NAME CONVENTION.
56
@dataclass
67
class Field:
78
"""Represents a field in a dependency node with its type and readiness state."""

tests/failing_tests/xdp_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from pythonbpf import bpf, map, section, bpfglobal, compile, compile_to_ir
22
from pythonbpf.maps import HashMap
33
from pythonbpf.helper import XDP_PASS
4-
from vmlinux import struct_ring_buffer_per_cpu # noqa: F401
5-
from vmlinux import struct_xdp_buff # noqa: F401
64
from vmlinux import struct_xdp_md
5+
from vmlinux import struct_ring_buffer_per_cpu # noqa: F401
6+
7+
# from vmlinux import struct_xdp_buff # noqa: F401
8+
# from vmlinux import struct_xdp_md
79
from ctypes import c_int64
810

911
# Instructions to how to run this program
@@ -44,4 +46,3 @@ def LICENSE() -> str:
4446

4547

4648
compile_to_ir("xdp_pass.py", "xdp_pass.ll")
47-
compile()

0 commit comments

Comments
 (0)