|
5 | 5 | from typing import Dict |
6 | 6 |
|
7 | 7 | from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes |
8 | | -from .type_normalization import convert_to_bool, handle_comparator |
| 8 | +from .type_normalization import ( |
| 9 | + convert_to_bool, |
| 10 | + handle_comparator, |
| 11 | + get_base_type_and_depth, |
| 12 | + deref_to_depth, |
| 13 | +) |
9 | 14 |
|
10 | 15 | logger: Logger = logging.getLogger(__name__) |
11 | 16 |
|
12 | 17 |
|
| 18 | +def get_operand_value( |
| 19 | + func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None |
| 20 | +): |
| 21 | + """Extract the value from an operand, handling variables and constants.""" |
| 22 | + logger.info(f"Getting operand value for: {ast.dump(operand)}") |
| 23 | + if isinstance(operand, ast.Name): |
| 24 | + if operand.id in local_sym_tab: |
| 25 | + var = local_sym_tab[operand.id].var |
| 26 | + var_type = var.type |
| 27 | + base_type, depth = get_base_type_and_depth(var_type) |
| 28 | + logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}") |
| 29 | + val = deref_to_depth(func, builder, var, depth) |
| 30 | + return val |
| 31 | + raise ValueError(f"Undefined variable: {operand.id}") |
| 32 | + elif isinstance(operand, ast.Constant): |
| 33 | + if isinstance(operand.value, int): |
| 34 | + cst = ir.Constant(ir.IntType(64), int(operand.value)) |
| 35 | + return cst |
| 36 | + raise TypeError(f"Unsupported constant type: {type(operand.value)}") |
| 37 | + elif isinstance(operand, ast.BinOp): |
| 38 | + res = _handle_binary_op_impl( |
| 39 | + func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab |
| 40 | + ) |
| 41 | + return res |
| 42 | + else: |
| 43 | + res = eval_expr( |
| 44 | + func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab |
| 45 | + ) |
| 46 | + if res is None: |
| 47 | + raise ValueError(f"Failed to evaluate call expression: {operand}") |
| 48 | + val, _ = res |
| 49 | + logger.info(f"Evaluated expr to {val} of type {val.type}") |
| 50 | + base_type, depth = get_base_type_and_depth(val.type) |
| 51 | + if depth > 0: |
| 52 | + val = deref_to_depth(func, builder, val, depth) |
| 53 | + return val |
| 54 | + raise TypeError(f"Unsupported operand type: {type(operand)}") |
| 55 | + |
| 56 | + |
| 57 | +def _handle_binary_op_impl( |
| 58 | + func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None |
| 59 | +): |
| 60 | + op = rval.op |
| 61 | + left = get_operand_value( |
| 62 | + func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab |
| 63 | + ) |
| 64 | + right = get_operand_value( |
| 65 | + func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab |
| 66 | + ) |
| 67 | + logger.info(f"left is {left}, right is {right}, op is {op}") |
| 68 | + |
| 69 | + # NOTE: Before doing the operation, if the operands are integers |
| 70 | + # we always extend them to i64. The assignment to LHS will take |
| 71 | + # care of truncation if needed. |
| 72 | + if isinstance(left.type, ir.IntType) and left.type.width < 64: |
| 73 | + left = builder.sext(left, ir.IntType(64)) |
| 74 | + if isinstance(right.type, ir.IntType) and right.type.width < 64: |
| 75 | + right = builder.sext(right, ir.IntType(64)) |
| 76 | + |
| 77 | + # Map AST operation nodes to LLVM IR builder methods |
| 78 | + op_map = { |
| 79 | + ast.Add: builder.add, |
| 80 | + ast.Sub: builder.sub, |
| 81 | + ast.Mult: builder.mul, |
| 82 | + ast.Div: builder.sdiv, |
| 83 | + ast.Mod: builder.srem, |
| 84 | + ast.LShift: builder.shl, |
| 85 | + ast.RShift: builder.lshr, |
| 86 | + ast.BitOr: builder.or_, |
| 87 | + ast.BitXor: builder.xor, |
| 88 | + ast.BitAnd: builder.and_, |
| 89 | + ast.FloorDiv: builder.udiv, |
| 90 | + } |
| 91 | + |
| 92 | + if type(op) in op_map: |
| 93 | + result = op_map[type(op)](left, right) |
| 94 | + return result |
| 95 | + else: |
| 96 | + raise SyntaxError("Unsupported binary operation") |
| 97 | + |
| 98 | + |
| 99 | +def _handle_binary_op( |
| 100 | + func, |
| 101 | + module, |
| 102 | + rval, |
| 103 | + builder, |
| 104 | + var_name, |
| 105 | + local_sym_tab, |
| 106 | + map_sym_tab, |
| 107 | + structs_sym_tab=None, |
| 108 | +): |
| 109 | + result = _handle_binary_op_impl( |
| 110 | + func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab |
| 111 | + ) |
| 112 | + if var_name and var_name in local_sym_tab: |
| 113 | + logger.info( |
| 114 | + f"Storing result {result} into variable {local_sym_tab[var_name].var}" |
| 115 | + ) |
| 116 | + builder.store(result, local_sym_tab[var_name].var) |
| 117 | + return result, result.type |
| 118 | + |
| 119 | + |
13 | 120 | def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): |
14 | 121 | """Handle ast.Name expressions.""" |
15 | 122 | if expr.id in local_sym_tab: |
@@ -194,8 +301,6 @@ def _handle_unary_op( |
194 | 301 | logger.error("Only 'not' and '-' unary operators are supported") |
195 | 302 | return None |
196 | 303 |
|
197 | | - from pythonbpf.binary_ops import get_operand_value |
198 | | - |
199 | 304 | operand = get_operand_value( |
200 | 305 | func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab |
201 | 306 | ) |
@@ -421,9 +526,7 @@ def eval_expr( |
421 | 526 | elif isinstance(expr, ast.Attribute): |
422 | 527 | return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder) |
423 | 528 | elif isinstance(expr, ast.BinOp): |
424 | | - from pythonbpf.binary_ops import handle_binary_op |
425 | | - |
426 | | - return handle_binary_op( |
| 529 | + return _handle_binary_op( |
427 | 530 | func, |
428 | 531 | module, |
429 | 532 | expr, |
|
0 commit comments