|
5 | 5 | from typing import Dict |
6 | 6 |
|
7 | 7 | from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes |
| 8 | +from .call_registry import CallHandlerRegistry |
8 | 9 | from .type_normalization import ( |
9 | 10 | convert_to_bool, |
10 | 11 | handle_comparator, |
|
14 | 15 |
|
15 | 16 | logger: Logger = logging.getLogger(__name__) |
16 | 17 |
|
| 18 | +# ============================================================================ |
| 19 | +# Leaf Handlers (No Recursive eval_expr calls) |
| 20 | +# ============================================================================ |
17 | 21 |
|
18 | | -class CallHandlerRegistry: |
19 | | - """Registry for handling different types of calls (helpers, etc.)""" |
20 | 22 |
|
21 | | - _handler = None |
| 23 | +def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): |
| 24 | + """Handle ast.Name expressions.""" |
| 25 | + if expr.id in local_sym_tab: |
| 26 | + var = local_sym_tab[expr.id].var |
| 27 | + val = builder.load(var) |
| 28 | + return val, local_sym_tab[expr.id].ir_type |
| 29 | + else: |
| 30 | + logger.info(f"Undefined variable {expr.id}") |
| 31 | + return None |
| 32 | + |
| 33 | + |
| 34 | +def _handle_constant_expr(module, builder, expr: ast.Constant): |
| 35 | + """Handle ast.Constant expressions.""" |
| 36 | + if isinstance(expr.value, int) or isinstance(expr.value, bool): |
| 37 | + return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) |
| 38 | + elif isinstance(expr.value, str): |
| 39 | + str_name = f".str.{id(expr)}" |
| 40 | + str_bytes = expr.value.encode("utf-8") + b"\x00" |
| 41 | + str_type = ir.ArrayType(ir.IntType(8), len(str_bytes)) |
| 42 | + str_constant = ir.Constant(str_type, bytearray(str_bytes)) |
| 43 | + |
| 44 | + # Create global variable |
| 45 | + global_str = ir.GlobalVariable(module, str_type, name=str_name) |
| 46 | + global_str.linkage = "internal" |
| 47 | + global_str.global_constant = True |
| 48 | + global_str.initializer = str_constant |
| 49 | + |
| 50 | + str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) |
| 51 | + return str_ptr, ir.PointerType(ir.IntType(8)) |
| 52 | + else: |
| 53 | + logger.error(f"Unsupported constant type {ast.dump(expr)}") |
| 54 | + return None |
| 55 | + |
| 56 | + |
| 57 | +def _handle_attribute_expr( |
| 58 | + expr: ast.Attribute, |
| 59 | + local_sym_tab: Dict, |
| 60 | + structs_sym_tab: Dict, |
| 61 | + builder: ir.IRBuilder, |
| 62 | +): |
| 63 | + """Handle ast.Attribute expressions for struct field access.""" |
| 64 | + if isinstance(expr.value, ast.Name): |
| 65 | + var_name = expr.value.id |
| 66 | + attr_name = expr.attr |
| 67 | + if var_name in local_sym_tab: |
| 68 | + var_ptr, var_type, var_metadata = local_sym_tab[var_name] |
| 69 | + logger.info(f"Loading attribute {attr_name} from variable {var_name}") |
| 70 | + logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") |
| 71 | + metadata = structs_sym_tab[var_metadata] |
| 72 | + if attr_name in metadata.fields: |
| 73 | + gep = metadata.gep(builder, var_ptr, attr_name) |
| 74 | + val = builder.load(gep) |
| 75 | + field_type = metadata.field_type(attr_name) |
| 76 | + return val, field_type |
| 77 | + return None |
22 | 78 |
|
23 | | - @classmethod |
24 | | - def set_handler(cls, handler): |
25 | | - """Set the handler for unknown calls""" |
26 | | - cls._handler = handler |
27 | 79 |
|
28 | | - @classmethod |
29 | | - def handle_call( |
30 | | - cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab |
| 80 | +def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder): |
| 81 | + """Handle deref function calls.""" |
| 82 | + logger.info(f"Handling deref {ast.dump(expr)}") |
| 83 | + if len(expr.args) != 1: |
| 84 | + logger.info("deref takes exactly one argument") |
| 85 | + return None |
| 86 | + |
| 87 | + arg = expr.args[0] |
| 88 | + if ( |
| 89 | + isinstance(arg, ast.Call) |
| 90 | + and isinstance(arg.func, ast.Name) |
| 91 | + and arg.func.id == "deref" |
31 | 92 | ): |
32 | | - """Handle a call using the registered handler""" |
33 | | - if cls._handler is None: |
| 93 | + logger.info("Multiple deref not supported") |
| 94 | + return None |
| 95 | + |
| 96 | + if isinstance(arg, ast.Name): |
| 97 | + if arg.id in local_sym_tab: |
| 98 | + arg_ptr = local_sym_tab[arg.id].var |
| 99 | + else: |
| 100 | + logger.info(f"Undefined variable {arg.id}") |
34 | 101 | return None |
35 | | - return cls._handler( |
36 | | - call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab |
37 | | - ) |
| 102 | + else: |
| 103 | + logger.info("Unsupported argument type for deref") |
| 104 | + return None |
| 105 | + |
| 106 | + if arg_ptr is None: |
| 107 | + logger.info("Failed to evaluate deref argument") |
| 108 | + return None |
| 109 | + |
| 110 | + # Load the value from pointer |
| 111 | + val = builder.load(arg_ptr) |
| 112 | + return val, local_sym_tab[arg.id].ir_type |
| 113 | + |
| 114 | + |
| 115 | +# ============================================================================ |
| 116 | +# Binary Operations |
| 117 | +# ============================================================================ |
38 | 118 |
|
39 | 119 |
|
40 | 120 | def get_operand_value( |
@@ -139,96 +219,9 @@ def _handle_binary_op( |
139 | 219 | return result, result.type |
140 | 220 |
|
141 | 221 |
|
142 | | -def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): |
143 | | - """Handle ast.Name expressions.""" |
144 | | - if expr.id in local_sym_tab: |
145 | | - var = local_sym_tab[expr.id].var |
146 | | - val = builder.load(var) |
147 | | - return val, local_sym_tab[expr.id].ir_type |
148 | | - else: |
149 | | - logger.info(f"Undefined variable {expr.id}") |
150 | | - return None |
151 | | - |
152 | | - |
153 | | -def _handle_constant_expr(module, builder, expr: ast.Constant): |
154 | | - """Handle ast.Constant expressions.""" |
155 | | - if isinstance(expr.value, int) or isinstance(expr.value, bool): |
156 | | - return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) |
157 | | - elif isinstance(expr.value, str): |
158 | | - str_name = f".str.{id(expr)}" |
159 | | - str_bytes = expr.value.encode("utf-8") + b"\x00" |
160 | | - str_type = ir.ArrayType(ir.IntType(8), len(str_bytes)) |
161 | | - str_constant = ir.Constant(str_type, bytearray(str_bytes)) |
162 | | - |
163 | | - # Create global variable |
164 | | - global_str = ir.GlobalVariable(module, str_type, name=str_name) |
165 | | - global_str.linkage = "internal" |
166 | | - global_str.global_constant = True |
167 | | - global_str.initializer = str_constant |
168 | | - |
169 | | - str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) |
170 | | - return str_ptr, ir.PointerType(ir.IntType(8)) |
171 | | - else: |
172 | | - logger.error(f"Unsupported constant type {ast.dump(expr)}") |
173 | | - return None |
174 | | - |
175 | | - |
176 | | -def _handle_attribute_expr( |
177 | | - expr: ast.Attribute, |
178 | | - local_sym_tab: Dict, |
179 | | - structs_sym_tab: Dict, |
180 | | - builder: ir.IRBuilder, |
181 | | -): |
182 | | - """Handle ast.Attribute expressions for struct field access.""" |
183 | | - if isinstance(expr.value, ast.Name): |
184 | | - var_name = expr.value.id |
185 | | - attr_name = expr.attr |
186 | | - if var_name in local_sym_tab: |
187 | | - var_ptr, var_type, var_metadata = local_sym_tab[var_name] |
188 | | - logger.info(f"Loading attribute {attr_name} from variable {var_name}") |
189 | | - logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") |
190 | | - metadata = structs_sym_tab[var_metadata] |
191 | | - if attr_name in metadata.fields: |
192 | | - gep = metadata.gep(builder, var_ptr, attr_name) |
193 | | - val = builder.load(gep) |
194 | | - field_type = metadata.field_type(attr_name) |
195 | | - return val, field_type |
196 | | - return None |
197 | | - |
198 | | - |
199 | | -def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder): |
200 | | - """Handle deref function calls.""" |
201 | | - logger.info(f"Handling deref {ast.dump(expr)}") |
202 | | - if len(expr.args) != 1: |
203 | | - logger.info("deref takes exactly one argument") |
204 | | - return None |
205 | | - |
206 | | - arg = expr.args[0] |
207 | | - if ( |
208 | | - isinstance(arg, ast.Call) |
209 | | - and isinstance(arg.func, ast.Name) |
210 | | - and arg.func.id == "deref" |
211 | | - ): |
212 | | - logger.info("Multiple deref not supported") |
213 | | - return None |
214 | | - |
215 | | - if isinstance(arg, ast.Name): |
216 | | - if arg.id in local_sym_tab: |
217 | | - arg_ptr = local_sym_tab[arg.id].var |
218 | | - else: |
219 | | - logger.info(f"Undefined variable {arg.id}") |
220 | | - return None |
221 | | - else: |
222 | | - logger.info("Unsupported argument type for deref") |
223 | | - return None |
224 | | - |
225 | | - if arg_ptr is None: |
226 | | - logger.info("Failed to evaluate deref argument") |
227 | | - return None |
228 | | - |
229 | | - # Load the value from pointer |
230 | | - val = builder.load(arg_ptr) |
231 | | - return val, local_sym_tab[arg.id].ir_type |
| 222 | +# ============================================================================ |
| 223 | +# Comparison and Unary Operations |
| 224 | +# ============================================================================ |
232 | 225 |
|
233 | 226 |
|
234 | 227 | def _handle_ctypes_call( |
@@ -341,6 +334,11 @@ def _handle_unary_op( |
341 | 334 | return result, ir.IntType(64) |
342 | 335 |
|
343 | 336 |
|
| 337 | +# ============================================================================ |
| 338 | +# Boolean Operations |
| 339 | +# ============================================================================ |
| 340 | + |
| 341 | + |
344 | 342 | def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): |
345 | 343 | """Handle `and` boolean operations.""" |
346 | 344 |
|
@@ -471,6 +469,11 @@ def _handle_boolean_op( |
471 | 469 | return None |
472 | 470 |
|
473 | 471 |
|
| 472 | +# ============================================================================ |
| 473 | +# Expression Dispatcher |
| 474 | +# ============================================================================ |
| 475 | + |
| 476 | + |
474 | 477 | def eval_expr( |
475 | 478 | func, |
476 | 479 | module, |
|
0 commit comments