|
1 | 1 | import logging |
2 | 2 | from typing import Any |
3 | | - |
| 3 | +import ctypes |
4 | 4 | from llvmlite import ir |
5 | 5 |
|
6 | 6 | from pythonbpf.local_symbol import LocalSymbol |
@@ -94,32 +94,30 @@ def handle_vmlinux_struct_field( |
94 | 94 | f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}" |
95 | 95 | ) |
96 | 96 | python_type: type = var_info.metadata |
97 | | - globvar_ir, field_data = self.get_field_type( |
98 | | - python_type.__name__, field_name |
99 | | - ) |
| 97 | + struct_name = python_type.__name__ |
| 98 | + globvar_ir, field_data = self.get_field_type(struct_name, field_name) |
100 | 99 | builder.function.args[0].type = ir.PointerType(ir.IntType(8)) |
101 | | - print(builder.function.args[0]) |
102 | 100 | field_ptr = self.load_ctx_field( |
103 | | - builder, builder.function.args[0], globvar_ir |
| 101 | + builder, builder.function.args[0], globvar_ir, field_data, struct_name |
104 | 102 | ) |
105 | | - print(field_ptr) |
106 | 103 | # Return pointer to field and field type |
107 | 104 | return field_ptr, field_data |
108 | 105 | else: |
109 | 106 | raise RuntimeError("Variable accessed not found in symbol table") |
110 | 107 |
|
111 | 108 | @staticmethod |
112 | | - def load_ctx_field(builder, ctx_arg, offset_global): |
| 109 | + def load_ctx_field(builder, ctx_arg, offset_global, field_data, struct_name=None): |
113 | 110 | """ |
114 | 111 | Generate LLVM IR to load a field from BPF context using offset. |
115 | 112 |
|
116 | 113 | Args: |
117 | 114 | builder: llvmlite IRBuilder instance |
118 | 115 | ctx_arg: The context pointer argument (ptr/i8*) |
119 | 116 | offset_global: Global variable containing the field offset (i64) |
120 | | -
|
| 117 | + field_data: contains data about the field |
| 118 | + struct_name: Name of the struct being accessed (optional) |
121 | 119 | Returns: |
122 | | - The loaded value (i64 register) |
| 120 | + The loaded value (i64 register or appropriately sized) |
123 | 121 | """ |
124 | 122 |
|
125 | 123 | # Load the offset value |
@@ -164,13 +162,61 @@ def load_ctx_field(builder, ctx_arg, offset_global): |
164 | 162 | passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True |
165 | 163 | ) |
166 | 164 |
|
167 | | - # Bitcast to i64* (assuming field is 64-bit, adjust if needed) |
168 | | - i64_ptr_type = ir.PointerType(ir.IntType(64)) |
169 | | - typed_ptr = builder.bitcast(verified_ptr, i64_ptr_type) |
| 165 | + # Determine the appropriate IR type based on field information |
| 166 | + int_width = 64 # Default to 64-bit |
| 167 | + needs_zext = False # Track if we need zero-extension for xdp_md |
| 168 | + |
| 169 | + if field_data is not None: |
| 170 | + # Try to determine the size from field metadata |
| 171 | + if field_data.type.__module__ == ctypes.__name__: |
| 172 | + try: |
| 173 | + field_size_bytes = ctypes.sizeof(field_data.type) |
| 174 | + field_size_bits = field_size_bytes * 8 |
| 175 | + |
| 176 | + if field_size_bits in [8, 16, 32, 64]: |
| 177 | + int_width = field_size_bits |
| 178 | + logger.info(f"Determined field size: {int_width} bits") |
| 179 | + |
| 180 | + # Special handling for struct_xdp_md i32 fields |
| 181 | + # Load as i32 but extend to i64 before storing |
| 182 | + if struct_name == "struct_xdp_md" and int_width == 32: |
| 183 | + needs_zext = True |
| 184 | + logger.info( |
| 185 | + "struct_xdp_md i32 field detected, will zero-extend to i64" |
| 186 | + ) |
| 187 | + else: |
| 188 | + logger.warning( |
| 189 | + f"Unusual field size {field_size_bits} bits, using default 64" |
| 190 | + ) |
| 191 | + except Exception as e: |
| 192 | + logger.warning( |
| 193 | + f"Could not determine field size: {e}, using default 64" |
| 194 | + ) |
| 195 | + |
| 196 | + elif field_data.type.__module__ == "vmlinux": |
| 197 | + # For pointers to structs or complex vmlinux types |
| 198 | + if field_data.ctype_complex_type is not None and issubclass( |
| 199 | + field_data.ctype_complex_type, ctypes._Pointer |
| 200 | + ): |
| 201 | + int_width = 64 # Pointers are always 64-bit |
| 202 | + logger.info("Field is a pointer type, using 64 bits") |
| 203 | + # TODO: Add handling for other complex types (arrays, embedded structs, etc.) |
| 204 | + else: |
| 205 | + logger.warning("Complex vmlinux field type, using default 64 bits") |
| 206 | + |
| 207 | + # Bitcast to appropriate pointer type based on determined width |
| 208 | + ptr_type = ir.PointerType(ir.IntType(int_width)) |
| 209 | + |
| 210 | + typed_ptr = builder.bitcast(verified_ptr, ptr_type) |
170 | 211 |
|
171 | 212 | # Load and return the value |
172 | 213 | value = builder.load(typed_ptr) |
173 | 214 |
|
| 215 | + # Zero-extend i32 to i64 for struct_xdp_md fields |
| 216 | + if needs_zext: |
| 217 | + value = builder.zext(value, ir.IntType(64)) |
| 218 | + logger.info("Zero-extended i32 value to i64 for struct_xdp_md field") |
| 219 | + |
174 | 220 | return value |
175 | 221 |
|
176 | 222 | def has_field(self, struct_name, field_name): |
|
0 commit comments