|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import IO, TypeVar |
| 4 | + |
| 5 | +from kirin import emit, interp |
| 6 | + |
| 7 | +from .stmts import Invoke, Return, Function |
| 8 | +from ._dialect import dialect |
| 9 | + |
| 10 | + |
| 11 | +@dialect.register(key="emit.julia") |
| 12 | +class Julia(interp.MethodTable): |
| 13 | + |
| 14 | + IO_t = TypeVar("IO_t", bound=IO) |
| 15 | + |
| 16 | + @interp.impl(Return) |
| 17 | + def return_( |
| 18 | + self, emit: emit.Julia[IO_t], frame: emit.JuliaFrame[IO_t], node: Return |
| 19 | + ): |
| 20 | + value = frame.get(node.value) |
| 21 | + frame.write_line(f"return {value}") |
| 22 | + |
| 23 | + @interp.impl(Invoke) |
| 24 | + def invoke( |
| 25 | + self, emit: emit.Julia[IO_t], frame: emit.JuliaFrame[IO_t], node: Invoke |
| 26 | + ): |
| 27 | + func_name = emit.callables.get(node.callee.code) |
| 28 | + if func_name is None: |
| 29 | + emit.callable_to_emit.append(node.callee.code) |
| 30 | + func_name = emit.callables.add(node.callee.code) |
| 31 | + |
| 32 | + _, call_expr = emit.call( |
| 33 | + node.callee.code, func_name, *frame.get_values(node.args) |
| 34 | + ) |
| 35 | + frame.write_line(f"{frame.ssa[node.result]} = {call_expr}") |
| 36 | + return (frame.ssa[node.result],) |
| 37 | + |
| 38 | + @interp.impl(Function) |
| 39 | + def function( |
| 40 | + self, emit_: emit.Julia[IO_t], frame: emit.JuliaFrame[IO_t], node: Function |
| 41 | + ): |
| 42 | + func_name = emit_.callables[node] |
| 43 | + frame.set(node.body.blocks[0].args[0], func_name) |
| 44 | + argnames_: list[str] = [] |
| 45 | + for arg in node.body.blocks[0].args[1:]: |
| 46 | + frame.set(arg, name := frame.ssa[arg]) |
| 47 | + argnames_.append(name) |
| 48 | + |
| 49 | + argnames = ", ".join(argnames_) |
| 50 | + frame.write_line(f"function {func_name}({argnames})") |
| 51 | + with frame.indent(): |
| 52 | + for block in node.body.blocks: |
| 53 | + frame.current_block = block |
| 54 | + frame.write_line(f"@label {frame.block[block]}") |
| 55 | + for arg in block.args: |
| 56 | + frame.set(arg, frame.ssa[arg]) |
| 57 | + |
| 58 | + for stmt in block.stmts: |
| 59 | + frame.current_stmt = stmt |
| 60 | + stmt_results = emit_.frame_eval(frame, stmt) |
| 61 | + |
| 62 | + match stmt_results: |
| 63 | + case tuple(): |
| 64 | + frame.set_values(stmt._results, stmt_results) |
| 65 | + case _: |
| 66 | + continue |
| 67 | + frame.write_line("end\n") |
0 commit comments