from llvmlite import ir, binding
import atexit

_SUPPORTED = {
    "int32": (ir.IntType(32), 4),
    "int64": (ir.IntType(64), 8),
    "float32": (ir.FloatType(), 4),
    "float64": (ir.DoubleType(), 8),
}

# Keep the MCJIT engines alive so the compiled code is not freed, and cache one
# function address per dtype.
_engines = {}
_target_machine = None
_fn_ptr_cache = {}

def _cleanup():
    """Clean up LLVM resources on exit."""
    global _engines, _target_machine, _fn_ptr_cache
    _engines.clear()
    _target_machine = None
    _fn_ptr_cache.clear()

atexit.register(_cleanup)

def _ensure_target_machine():
    global _target_machine
    if _target_machine is not None:
        return

    try:
        binding.initialize()
        binding.initialize_native_target()
        binding.initialize_native_asmprinter()

        target = binding.Target.from_default_triple()
        _target_machine = target.create_target_machine()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLVM target machine: {e}") from e

def get_bubble_sort_ptr(dtype: str) -> int:
    """Return the address of a JIT-compiled ``void bubble_sort_<dtype>(T *arr, int32 n)``."""
    dtype = dtype.lower().strip()
    if dtype not in _SUPPORTED:
        raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

    return _materialize(dtype)

def _build_bubble_sort_ir(dtype: str) -> str:
    if dtype not in _SUPPORTED:
        raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

    T, _ = _SUPPORTED[dtype]
    i32 = ir.IntType(32)
    i64 = ir.IntType(64)

    mod = ir.Module(name=f"bubble_sort_{dtype}_module")
    fn_name = f"bubble_sort_{dtype}"

    # void bubble_sort_<dtype>(T *arr, int32 n)
    fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
    fn = ir.Function(mod, fn_ty, name=fn_name)

    arr, n = fn.args
    arr.name, n.name = "arr", "n"

    b_entry = fn.append_basic_block("entry")
    b_outer = fn.append_basic_block("outer")
    b_inner_init = fn.append_basic_block("inner.init")
    b_inner = fn.append_basic_block("inner")
    b_body = fn.append_basic_block("body")
    b_swap = fn.append_basic_block("swap")
    b_inner_latch = fn.append_basic_block("inner.latch")
    b_outer_latch = fn.append_basic_block("outer.latch")
    b_exit = fn.append_basic_block("exit")

    b = ir.IRBuilder(b_entry)

    # Arrays of length 0 or 1 are already sorted.
    cond_trivial = b.icmp_signed("<=", n, ir.Constant(i32, 1))
    b.cbranch(cond_trivial, b_exit, b_outer)

    # Outer loop: for (i = 0; i < n - 1; i++)
    b.position_at_end(b_outer)
    i_phi = b.phi(i32, name="i")
    i_phi.add_incoming(ir.Constant(i32, 0), b_entry)

    n1 = b.sub(n, ir.Constant(i32, 1), name="n_minus_1")
    cond_outer = b.icmp_signed("<", i_phi, n1)
    b.cbranch(cond_outer, b_inner_init, b_exit)

    b.position_at_end(b_inner_init)
    # After pass i, the last i elements are already in place, so the inner loop
    # only scans the first n - 1 - i pairs.
    inner_limit = b.sub(n1, i_phi, name="inner_limit")
    b.branch(b_inner)

    # Inner loop: for (j = 0; j < n - 1 - i; j++)
    b.position_at_end(b_inner)
    j_phi = b.phi(i32, name="j")
    j_phi.add_incoming(ir.Constant(i32, 0), b_inner_init)

    cond_inner = b.icmp_signed("<", j_phi, inner_limit)
    b.cbranch(cond_inner, b_body, b_outer_latch)

    # Compare arr[j] and arr[j + 1].
    b.position_at_end(b_body)
    j64 = b.sext(j_phi, i64)
    jp1 = b.add(j_phi, ir.Constant(i32, 1))
    jp1_64 = b.sext(jp1, i64)

    ptr_j = b.gep(arr, [j64], inbounds=True)
    ptr_jp1 = b.gep(arr, [jp1_64], inbounds=True)

    val_j = b.load(ptr_j)
    val_jp1 = b.load(ptr_jp1)

    if isinstance(T, ir.IntType):
        should_swap = b.icmp_signed(">", val_j, val_jp1)
    else:
        should_swap = b.fcmp_ordered(">", val_j, val_jp1)

    b.cbranch(should_swap, b_swap, b_inner_latch)

    # Swap arr[j] and arr[j + 1] in place.
    b.position_at_end(b_swap)
    b.store(val_jp1, ptr_j)
    b.store(val_j, ptr_jp1)
    b.branch(b_inner_latch)

    b.position_at_end(b_inner_latch)
    j_next = b.add(j_phi, ir.Constant(i32, 1))
    j_phi.add_incoming(j_next, b_inner_latch)
    b.branch(b_inner)

    b.position_at_end(b_outer_latch)
    i_next = b.add(i_phi, ir.Constant(i32, 1))
    i_phi.add_incoming(i_next, b_outer_latch)
    b.branch(b_outer)

    b.position_at_end(b_exit)
    b.ret_void()

    return str(mod)

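# Note (illustrative, not part of the original change): the builder above returns
# textual LLVM IR, so the block/phi structure can be inspected without JIT-compiling
# anything, e.g. print(_build_bubble_sort_ir("int32")).
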
def _materialize(dtype: str) -> int:
    _ensure_target_machine()

    if dtype in _fn_ptr_cache:
        return _fn_ptr_cache[dtype]

    try:
        llvm_ir = _build_bubble_sort_ir(dtype)
        mod = binding.parse_assembly(llvm_ir)
        mod.verify()

        engine = binding.create_mcjit_compiler(mod, _target_machine)
        engine.finalize_object()
        engine.run_static_constructors()

        addr = engine.get_function_address(f"bubble_sort_{dtype}")
        if not addr:
            raise RuntimeError(f"Failed to get address for bubble_sort_{dtype}")

        _fn_ptr_cache[dtype] = addr
        _engines[dtype] = engine

        return addr

    except Exception as e:
        raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}") from e
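
# --- Example usage (illustrative sketch, not part of the original change) ---
# The integer returned by get_bubble_sort_ptr() is the address of the JIT-compiled
# void bubble_sort_<dtype>(T *arr, int32 n).  One way to call it is through ctypes
# with a matching NumPy buffer; the names below (addr, sort_fn, data) are
# hypothetical and exist only for this demo.
if __name__ == "__main__":
    import ctypes
    import numpy as np

    addr = get_bubble_sort_ptr("float64")
    # Prototype matching void(double *, int32), bound to the JIT'd address.
    sort_fn = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.c_double), ctypes.c_int32)(addr)

    data = np.array([3.5, -1.0, 2.25, 0.0], dtype=np.float64)
    sort_fn(data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), data.size)
    print(data)  # expected: [-1.    0.    2.25  3.5 ]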