|
1 | 1 | import logging |
2 | 2 | from typing import Any |
3 | | - |
| 3 | +import ctypes |
4 | 4 | from llvmlite import ir |
5 | 5 |
|
6 | 6 | from pythonbpf.local_symbol import LocalSymbol |
@@ -98,26 +98,24 @@ def handle_vmlinux_struct_field( |
98 | 98 | python_type.__name__, field_name |
99 | 99 | ) |
100 | 100 | builder.function.args[0].type = ir.PointerType(ir.IntType(8)) |
101 | | - print(builder.function.args[0]) |
102 | 101 | field_ptr = self.load_ctx_field( |
103 | | - builder, builder.function.args[0], globvar_ir |
| 102 | + builder, builder.function.args[0], globvar_ir, field_data |
104 | 103 | ) |
105 | | - print(field_ptr) |
106 | 104 | # Return pointer to field and field type |
107 | 105 | return field_ptr, field_data |
108 | 106 | else: |
109 | 107 | raise RuntimeError("Variable accessed not found in symbol table") |
110 | 108 |
|
111 | 109 | @staticmethod |
112 | | - def load_ctx_field(builder, ctx_arg, offset_global): |
| 110 | + def load_ctx_field(builder, ctx_arg, offset_global, field_data): |
113 | 111 | """ |
114 | 112 | Generate LLVM IR to load a field from BPF context using offset. |
115 | 113 |
|
116 | 114 | Args: |
117 | 115 | builder: llvmlite IRBuilder instance |
118 | 116 | ctx_arg: The context pointer argument (ptr/i8*) |
119 | 117 | offset_global: Global variable containing the field offset (i64) |
120 | | -
|
| 118 | + field_data: contains data about the field |
121 | 119 | Returns: |
122 | 120 | The loaded value (i64 register) |
123 | 121 | """ |
@@ -164,9 +162,43 @@ def load_ctx_field(builder, ctx_arg, offset_global): |
164 | 162 | passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True |
165 | 163 | ) |
166 | 164 |
|
167 | | - # Bitcast to i64* (assuming field is 64-bit, adjust if needed) |
168 | | - i64_ptr_type = ir.PointerType(ir.IntType(64)) |
169 | | - typed_ptr = builder.bitcast(verified_ptr, i64_ptr_type) |
| 165 | + # Determine the appropriate IR type based on field information |
| 166 | + int_width = 64 # Default to 64-bit |
| 167 | + |
| 168 | + if field_data is not None: |
| 169 | + # Try to determine the size from field metadata |
| 170 | + if field_data.type.__module__ == ctypes.__name__: |
| 171 | + try: |
| 172 | + field_size_bytes = ctypes.sizeof(field_data.type) |
| 173 | + field_size_bits = field_size_bytes * 8 |
| 174 | + |
| 175 | + if field_size_bits in [8, 16, 32, 64]: |
| 176 | + int_width = field_size_bits |
| 177 | + logger.info(f"Determined field size: {int_width} bits") |
| 178 | + else: |
| 179 | + logger.warning( |
| 180 | + f"Unusual field size {field_size_bits} bits, using default 64" |
| 181 | + ) |
| 182 | + except Exception as e: |
| 183 | + logger.warning( |
| 184 | + f"Could not determine field size: {e}, using default 64" |
| 185 | + ) |
| 186 | + |
| 187 | + elif field_data.type.__module__ == "vmlinux": |
| 188 | + # For pointers to structs or complex vmlinux types |
| 189 | + if field_data.ctype_complex_type is not None and issubclass( |
| 190 | + field_data.ctype_complex_type, ctypes._Pointer |
| 191 | + ): |
| 192 | + int_width = 64 # Pointers are always 64-bit |
| 193 | + logger.info("Field is a pointer type, using 64 bits") |
| 194 | + # TODO: Add handling for other complex types (arrays, embedded structs, etc.) |
| 195 | + else: |
| 196 | + logger.warning("Complex vmlinux field type, using default 64 bits") |
| 197 | + |
| 198 | + # Bitcast to appropriate pointer type based on determined width |
| 199 | + ptr_type = ir.PointerType(ir.IntType(int_width)) |
| 200 | + |
| 201 | + typed_ptr = builder.bitcast(verified_ptr, ptr_type) |
170 | 202 |
|
171 | 203 | # Load and return the value |
172 | 204 | value = builder.load(typed_ptr) |
|
0 commit comments