Skip to content

Commit dbfe42c

Browse files
committed
fix: Fix typing, tests, etc
1 parent d07c0ef commit dbfe42c

File tree

6 files changed

+240
-13
lines changed

6 files changed

+240
-13
lines changed

examples/average.x

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,3 @@ fn average(x: f32, y: f32) -> f32:
88
summary: Returns the arithmetic mean of x and y.
99
```
1010
return (x + y) * 0.5;
11-
12-
fn main() -> i32:
13-
```
14-
title: main
15-
summary: Runs the print_star demo with a fixed size and exits with status 0.
16-
```
17-
print(average(10.0, 20.0))
18-
return 0

examples/fibonacci.x

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@ fn fib(x: i32) -> i32:
1111
return 1
1212
else:
1313
return fib(x-1)+fib(x-2)
14+
15+
16+
fn main() -> i32:
17+
```
18+
title: main
19+
summary: Runs the fibonacci demo and exits with status 0.
20+
```
21+
print(fib(10))
22+
return 0

src/arx/codegen.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""
2+
title: Arx LLVM-IR integration helpers.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import os
8+
import tempfile
9+
10+
from typing import Any, Callable, cast
11+
12+
import astx
13+
import xh
14+
15+
from irx import system
16+
from irx.builders.llvmliteir import (
17+
LLVMLiteIR as BaseLLVMLiteIR,
18+
)
19+
from irx.builders.llvmliteir import (
20+
LLVMLiteIRVisitor,
21+
is_fp_type,
22+
is_int_type,
23+
)
24+
from llvmlite import binding as llvm
25+
from llvmlite import ir
26+
from plum import dispatch
27+
28+
29+
class ArxLLVMLiteIRVisitor(LLVMLiteIRVisitor):
30+
"""
31+
title: Arx-specific LLVM-IR visitor customizations.
32+
"""
33+
34+
@dispatch
35+
def visit(self, node: astx.IfStmt) -> None:
36+
"""
37+
title: Generate LLVM IR for statement-style if blocks.
38+
parameters:
39+
node:
40+
type: astx.IfStmt
41+
"""
42+
self.visit(cast(Any, node.condition))
43+
cond_v = self.result_stack.pop() if self.result_stack else None
44+
if cond_v is None:
45+
raise Exception("codegen: Invalid condition expression.")
46+
47+
if is_fp_type(cond_v.type):
48+
cmp_instruction = self._llvm.ir_builder.fcmp_ordered
49+
zero_val = ir.Constant(cond_v.type, 0.0)
50+
else:
51+
cmp_instruction = self._llvm.ir_builder.icmp_signed
52+
zero_val = ir.Constant(cond_v.type, 0)
53+
54+
cond_v = cmp_instruction(
55+
"!=",
56+
cond_v,
57+
zero_val,
58+
)
59+
60+
then_bb = self._llvm.ir_builder.function.append_basic_block(
61+
"bb_if_then"
62+
)
63+
else_bb = self._llvm.ir_builder.function.append_basic_block(
64+
"bb_if_else"
65+
)
66+
merge_bb = self._llvm.ir_builder.function.append_basic_block(
67+
"bb_if_end"
68+
)
69+
70+
self._llvm.ir_builder.cbranch(cond_v, then_bb, else_bb)
71+
72+
self._llvm.ir_builder.position_at_start(then_bb)
73+
self.visit(cast(Any, node.then))
74+
then_v = self.result_stack.pop() if self.result_stack else None
75+
then_block_end = self._llvm.ir_builder.block
76+
then_terminated = then_block_end.terminator is not None
77+
if not then_terminated:
78+
self._llvm.ir_builder.branch(merge_bb)
79+
then_block_end = self._llvm.ir_builder.block
80+
81+
self._llvm.ir_builder.position_at_start(else_bb)
82+
else_v = None
83+
if node.else_ is not None:
84+
self.visit(cast(Any, node.else_))
85+
else_v = self.result_stack.pop() if self.result_stack else None
86+
else_block_end = self._llvm.ir_builder.block
87+
else_terminated = else_block_end.terminator is not None
88+
if not else_terminated:
89+
self._llvm.ir_builder.branch(merge_bb)
90+
else_block_end = self._llvm.ir_builder.block
91+
92+
if then_terminated and else_terminated:
93+
self._llvm.ir_builder.position_at_start(merge_bb)
94+
self._llvm.ir_builder.unreachable()
95+
return
96+
97+
self._llvm.ir_builder.position_at_start(merge_bb)
98+
99+
if (
100+
then_v is not None
101+
and else_v is not None
102+
and then_v.type == else_v.type
103+
and not then_terminated
104+
and not else_terminated
105+
):
106+
phi = self._llvm.ir_builder.phi(then_v.type, "iftmp")
107+
phi.add_incoming(then_v, then_block_end)
108+
phi.add_incoming(else_v, else_block_end)
109+
self.result_stack.append(phi)
110+
111+
@dispatch # type: ignore[no-redef]
112+
def visit(self, node: system.PrintExpr) -> None:
113+
"""
114+
title: Generate LLVM IR for PrintExpr with numeric support.
115+
parameters:
116+
node:
117+
type: system.PrintExpr
118+
"""
119+
self.visit(cast(Any, node.message))
120+
value = self.result_stack.pop() if self.result_stack else None
121+
if value is None:
122+
raise Exception("Invalid message in PrintExpr")
123+
124+
if isinstance(value.type, ir.PointerType) and (
125+
value.type.pointee == self._llvm.INT8_TYPE
126+
):
127+
ptr = value
128+
elif is_int_type(value.type):
129+
arg, fmt_str = self._normalize_int_for_printf(value)
130+
fmt_gv = self._get_or_create_format_global(fmt_str)
131+
ptr = self._snprintf_heap(fmt_gv, [arg])
132+
elif isinstance(
133+
value.type, (ir.FloatType, ir.DoubleType, ir.HalfType)
134+
):
135+
if isinstance(value.type, (ir.FloatType, ir.HalfType)):
136+
value_prom = self._llvm.ir_builder.fpext(
137+
value, self._llvm.DOUBLE_TYPE, "to_double"
138+
)
139+
else:
140+
value_prom = value
141+
fmt_gv = self._get_or_create_format_global("%.6f")
142+
ptr = self._snprintf_heap(fmt_gv, [value_prom])
143+
else:
144+
raise Exception(
145+
f"Unsupported print argument type: '{value.type}'."
146+
)
147+
148+
puts_fn = self._llvm.module.globals.get("puts")
149+
if puts_fn is None:
150+
puts_ty = ir.FunctionType(
151+
self._llvm.INT32_TYPE,
152+
[ir.PointerType(self._llvm.INT8_TYPE)],
153+
)
154+
puts_fn = ir.Function(self._llvm.module, puts_ty, name="puts")
155+
156+
self._llvm.ir_builder.call(cast(ir.Function, puts_fn), [ptr])
157+
self.result_stack.append(ir.Constant(self._llvm.INT32_TYPE, 0))
158+
159+
160+
class LLVMLiteIR(BaseLLVMLiteIR):
161+
"""
162+
title: LLVM-IR transpiler and compiler with Arx overrides.
163+
attributes:
164+
translator:
165+
type: ArxLLVMLiteIRVisitor
166+
"""
167+
168+
def __init__(self) -> None:
169+
"""
170+
title: Initialize LLVMIR.
171+
"""
172+
super().__init__()
173+
self.translator: ArxLLVMLiteIRVisitor = ArxLLVMLiteIRVisitor()
174+
175+
def build(self, node: astx.AST, output_file: str) -> None:
176+
"""
177+
title: >-
178+
Transpile the ASTx to LLVM-IR and build it to an executable file.
179+
parameters:
180+
node:
181+
type: astx.AST
182+
output_file:
183+
type: str
184+
"""
185+
self.translator = ArxLLVMLiteIRVisitor()
186+
result = self.translator.translate(node)
187+
188+
result_mod = llvm.parse_assembly(result)
189+
result_object = self.translator.target_machine.emit_object(result_mod)
190+
191+
with tempfile.NamedTemporaryFile(suffix="", delete=True) as temp_file:
192+
self.tmp_path = temp_file.name
193+
194+
file_path_o = f"{self.tmp_path}.o"
195+
with open(file_path_o, "wb") as file_handler:
196+
file_handler.write(result_object)
197+
198+
self.output_file = output_file
199+
200+
# fix xh typing
201+
clang: Callable[..., Any] = xh.clang
202+
clang(
203+
file_path_o,
204+
"-o",
205+
self.output_file,
206+
)
207+
os.chmod(self.output_file, 0o755)

src/arx/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
import astx
1313

14-
from irx.builders.llvmliteir import LLVMLiteIR
15-
14+
from arx.codegen import LLVMLiteIR
1615
from arx.io import ArxIO
1716
from arx.lexer import Lexer
1817
from arx.parser import Parser

tests/test_codegen_ast_output.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,36 @@
44

55
import pytest
66

7+
from arx.codegen import LLVMLiteIR
78
from arx.io import ArxIO
89
from arx.lexer import Lexer
910
from arx.parser import Parser
10-
from irx.builders.llvmliteir import LLVMLiteIR
1111

1212

1313
@pytest.mark.parametrize(
1414
"code",
1515
[
1616
"fn main():\n return 0.0 + 1.0",
1717
"fn main():\n return 1.0 + 2.0 * (3.0 - 2.0)",
18+
"fn main() -> i32:\n print(42)\n return 0",
19+
"fn main() -> i32:\n print(3.5)\n return 0",
20+
(
21+
"fn average(x: f32, y: f32) -> f32:\n"
22+
" return (x + y) * 0.5\n"
23+
"fn main() -> i32:\n"
24+
" print(average(10.0, 20.0))\n"
25+
" return 0"
26+
),
27+
(
28+
"fn fib(x: i32) -> i32:\n"
29+
" if x < 3:\n"
30+
" return 1\n"
31+
" else:\n"
32+
" return fib(x-1)+fib(x-2)\n"
33+
"fn main() -> i32:\n"
34+
" print(fib(10))\n"
35+
" return 0"
36+
),
1837
# "fn main():\n if (1 < 2):\n return 3\n else:\n return 2\n",
1938
],
2039
)
@@ -33,4 +52,5 @@ def test_ast_to_output(code: str) -> None:
3352

3453
module_ast = parser.parse(lexer.lex())
3554

36-
print(ir.translator.translate(module_ast))
55+
result = ir.translator.translate(module_ast)
56+
assert result

tests/test_codegen_file_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import pytest
1010

11+
from arx.codegen import LLVMLiteIR
1112
from arx.io import ArxIO
1213
from arx.lexer import Lexer
1314
from arx.parser import Parser
14-
from irx.builders.llvmliteir import LLVMLiteIR
1515

1616
TMP_PATH = Path("/tmp/arxtmp")
1717
TMP_PATH.mkdir(exist_ok=True)

0 commit comments

Comments
 (0)