|
| 1 | +""" |
| 2 | +title: Arx LLVM-IR integration helpers. |
| 3 | +""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import os |
| 8 | +import tempfile |
| 9 | + |
| 10 | +from typing import Any, Callable, cast |
| 11 | + |
| 12 | +import astx |
| 13 | +import xh |
| 14 | + |
| 15 | +from irx import system |
| 16 | +from irx.builders.llvmliteir import ( |
| 17 | + LLVMLiteIR as BaseLLVMLiteIR, |
| 18 | +) |
| 19 | +from irx.builders.llvmliteir import ( |
| 20 | + LLVMLiteIRVisitor, |
| 21 | + is_fp_type, |
| 22 | + is_int_type, |
| 23 | +) |
| 24 | +from llvmlite import binding as llvm |
| 25 | +from llvmlite import ir |
| 26 | +from plum import dispatch |
| 27 | + |
| 28 | + |
| 29 | +class ArxLLVMLiteIRVisitor(LLVMLiteIRVisitor): |
| 30 | + """ |
| 31 | + title: Arx-specific LLVM-IR visitor customizations. |
| 32 | + """ |
| 33 | + |
| 34 | + @dispatch |
| 35 | + def visit(self, node: astx.IfStmt) -> None: |
| 36 | + """ |
| 37 | + title: Generate LLVM IR for statement-style if blocks. |
| 38 | + parameters: |
| 39 | + node: |
| 40 | + type: astx.IfStmt |
| 41 | + """ |
| 42 | + self.visit(cast(Any, node.condition)) |
| 43 | + cond_v = self.result_stack.pop() if self.result_stack else None |
| 44 | + if cond_v is None: |
| 45 | + raise Exception("codegen: Invalid condition expression.") |
| 46 | + |
| 47 | + if is_fp_type(cond_v.type): |
| 48 | + cmp_instruction = self._llvm.ir_builder.fcmp_ordered |
| 49 | + zero_val = ir.Constant(cond_v.type, 0.0) |
| 50 | + else: |
| 51 | + cmp_instruction = self._llvm.ir_builder.icmp_signed |
| 52 | + zero_val = ir.Constant(cond_v.type, 0) |
| 53 | + |
| 54 | + cond_v = cmp_instruction( |
| 55 | + "!=", |
| 56 | + cond_v, |
| 57 | + zero_val, |
| 58 | + ) |
| 59 | + |
| 60 | + then_bb = self._llvm.ir_builder.function.append_basic_block( |
| 61 | + "bb_if_then" |
| 62 | + ) |
| 63 | + else_bb = self._llvm.ir_builder.function.append_basic_block( |
| 64 | + "bb_if_else" |
| 65 | + ) |
| 66 | + merge_bb = self._llvm.ir_builder.function.append_basic_block( |
| 67 | + "bb_if_end" |
| 68 | + ) |
| 69 | + |
| 70 | + self._llvm.ir_builder.cbranch(cond_v, then_bb, else_bb) |
| 71 | + |
| 72 | + self._llvm.ir_builder.position_at_start(then_bb) |
| 73 | + self.visit(cast(Any, node.then)) |
| 74 | + then_v = self.result_stack.pop() if self.result_stack else None |
| 75 | + then_block_end = self._llvm.ir_builder.block |
| 76 | + then_terminated = then_block_end.terminator is not None |
| 77 | + if not then_terminated: |
| 78 | + self._llvm.ir_builder.branch(merge_bb) |
| 79 | + then_block_end = self._llvm.ir_builder.block |
| 80 | + |
| 81 | + self._llvm.ir_builder.position_at_start(else_bb) |
| 82 | + else_v = None |
| 83 | + if node.else_ is not None: |
| 84 | + self.visit(cast(Any, node.else_)) |
| 85 | + else_v = self.result_stack.pop() if self.result_stack else None |
| 86 | + else_block_end = self._llvm.ir_builder.block |
| 87 | + else_terminated = else_block_end.terminator is not None |
| 88 | + if not else_terminated: |
| 89 | + self._llvm.ir_builder.branch(merge_bb) |
| 90 | + else_block_end = self._llvm.ir_builder.block |
| 91 | + |
| 92 | + if then_terminated and else_terminated: |
| 93 | + self._llvm.ir_builder.position_at_start(merge_bb) |
| 94 | + self._llvm.ir_builder.unreachable() |
| 95 | + return |
| 96 | + |
| 97 | + self._llvm.ir_builder.position_at_start(merge_bb) |
| 98 | + |
| 99 | + if ( |
| 100 | + then_v is not None |
| 101 | + and else_v is not None |
| 102 | + and then_v.type == else_v.type |
| 103 | + and not then_terminated |
| 104 | + and not else_terminated |
| 105 | + ): |
| 106 | + phi = self._llvm.ir_builder.phi(then_v.type, "iftmp") |
| 107 | + phi.add_incoming(then_v, then_block_end) |
| 108 | + phi.add_incoming(else_v, else_block_end) |
| 109 | + self.result_stack.append(phi) |
| 110 | + |
| 111 | + @dispatch # type: ignore[no-redef] |
| 112 | + def visit(self, node: system.PrintExpr) -> None: |
| 113 | + """ |
| 114 | + title: Generate LLVM IR for PrintExpr with numeric support. |
| 115 | + parameters: |
| 116 | + node: |
| 117 | + type: system.PrintExpr |
| 118 | + """ |
| 119 | + self.visit(cast(Any, node.message)) |
| 120 | + value = self.result_stack.pop() if self.result_stack else None |
| 121 | + if value is None: |
| 122 | + raise Exception("Invalid message in PrintExpr") |
| 123 | + |
| 124 | + if isinstance(value.type, ir.PointerType) and ( |
| 125 | + value.type.pointee == self._llvm.INT8_TYPE |
| 126 | + ): |
| 127 | + ptr = value |
| 128 | + elif is_int_type(value.type): |
| 129 | + arg, fmt_str = self._normalize_int_for_printf(value) |
| 130 | + fmt_gv = self._get_or_create_format_global(fmt_str) |
| 131 | + ptr = self._snprintf_heap(fmt_gv, [arg]) |
| 132 | + elif isinstance( |
| 133 | + value.type, (ir.FloatType, ir.DoubleType, ir.HalfType) |
| 134 | + ): |
| 135 | + if isinstance(value.type, (ir.FloatType, ir.HalfType)): |
| 136 | + value_prom = self._llvm.ir_builder.fpext( |
| 137 | + value, self._llvm.DOUBLE_TYPE, "to_double" |
| 138 | + ) |
| 139 | + else: |
| 140 | + value_prom = value |
| 141 | + fmt_gv = self._get_or_create_format_global("%.6f") |
| 142 | + ptr = self._snprintf_heap(fmt_gv, [value_prom]) |
| 143 | + else: |
| 144 | + raise Exception( |
| 145 | + f"Unsupported print argument type: '{value.type}'." |
| 146 | + ) |
| 147 | + |
| 148 | + puts_fn = self._llvm.module.globals.get("puts") |
| 149 | + if puts_fn is None: |
| 150 | + puts_ty = ir.FunctionType( |
| 151 | + self._llvm.INT32_TYPE, |
| 152 | + [ir.PointerType(self._llvm.INT8_TYPE)], |
| 153 | + ) |
| 154 | + puts_fn = ir.Function(self._llvm.module, puts_ty, name="puts") |
| 155 | + |
| 156 | + self._llvm.ir_builder.call(cast(ir.Function, puts_fn), [ptr]) |
| 157 | + self.result_stack.append(ir.Constant(self._llvm.INT32_TYPE, 0)) |
| 158 | + |
| 159 | + |
| 160 | +class LLVMLiteIR(BaseLLVMLiteIR): |
| 161 | + """ |
| 162 | + title: LLVM-IR transpiler and compiler with Arx overrides. |
| 163 | + attributes: |
| 164 | + translator: |
| 165 | + type: ArxLLVMLiteIRVisitor |
| 166 | + """ |
| 167 | + |
| 168 | + def __init__(self) -> None: |
| 169 | + """ |
| 170 | + title: Initialize LLVMIR. |
| 171 | + """ |
| 172 | + super().__init__() |
| 173 | + self.translator: ArxLLVMLiteIRVisitor = ArxLLVMLiteIRVisitor() |
| 174 | + |
| 175 | + def build(self, node: astx.AST, output_file: str) -> None: |
| 176 | + """ |
| 177 | + title: >- |
| 178 | + Transpile the ASTx to LLVM-IR and build it to an executable file. |
| 179 | + parameters: |
| 180 | + node: |
| 181 | + type: astx.AST |
| 182 | + output_file: |
| 183 | + type: str |
| 184 | + """ |
| 185 | + self.translator = ArxLLVMLiteIRVisitor() |
| 186 | + result = self.translator.translate(node) |
| 187 | + |
| 188 | + result_mod = llvm.parse_assembly(result) |
| 189 | + result_object = self.translator.target_machine.emit_object(result_mod) |
| 190 | + |
| 191 | + with tempfile.NamedTemporaryFile(suffix="", delete=True) as temp_file: |
| 192 | + self.tmp_path = temp_file.name |
| 193 | + |
| 194 | + file_path_o = f"{self.tmp_path}.o" |
| 195 | + with open(file_path_o, "wb") as file_handler: |
| 196 | + file_handler.write(result_object) |
| 197 | + |
| 198 | + self.output_file = output_file |
| 199 | + |
| 200 | + # fix xh typing |
| 201 | + clang: Callable[..., Any] = xh.clang |
| 202 | + clang( |
| 203 | + file_path_o, |
| 204 | + "-o", |
| 205 | + self.output_file, |
| 206 | + ) |
| 207 | + os.chmod(self.output_file, 0o755) |
0 commit comments