From ec52ff3568a5faed8ee7cbac8325957c68480f1e Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 27 Jul 2025 12:56:44 -0400 Subject: [PATCH] Start the code to raise from LLVM within Julia --- lib/ReactantCore/src/ReactantCore.jl | 60 +++++++++++++++++++++++++++- src/Reactant.jl | 2 +- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 1a21d8a922..bd5a35fe90 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -3,7 +3,7 @@ module ReactantCore using ExpressionExplorer: ExpressionExplorer using MacroTools: MacroTools -export @trace, within_compile, MissingTracedValue, promote_to_traced +export @trace, @raise, within_compile, MissingTracedValue, promote_to_traced # Traits function is_traced((@nospecialize x::T), seen=Base.IdSet()) where {T} @@ -637,4 +637,62 @@ Given an AbstractArray{TracedRNumber}, return or create an equivalent TracedRArr """ function materialize_traced_array end +""" + @raise expr + +Signals to raise the code block `expr` from LLVM during compilation. +""" +macro raise(args...) + expr = first(args) + if length(args) > 1 + error("@raise only supports a single block of code, got $(length(args))") + end + esc(_raise(expr)) +end + +function _raise(expr) + ee = ExpressionExplorer.compute_symbols_state(expr) + refs = ee.references + assi = ee.assignments + + # Prevent boxing when esc to we generate new names + arg_syms = [gensym(s) for s in refs] + ass_syms = [gensym(s) for s in refs] + + + expr_with_args = MacroTools.postwalk(expr) do x + for (i, r) in enumerate(refs) + if x == r + return arg_syms[i] + end + end + for (i, r) in enumerate(assi) + if x == r + return ass_syms[i] + end + end + + return x + end + + func = gensym(:raise_fn) + reactant_code_block = quote + function $func($(arg_syms...)) + $(expr_with_args) + end + ($func, ($(refs...),)) + end + cond_val(s) = :(@isdefined($s) ? $s : nothing) + + return quote + if $(within_compile)() && + $(any)($(is_traced), ($(cond_val.(refs)...),)) + $(reactant_code_block) + else + $(expr) + end + end +end + + end diff --git a/src/Reactant.jl b/src/Reactant.jl index 02893f9516..cf0dbe05c7 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -1,7 +1,7 @@ module Reactant using ReactantCore: - ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array + ReactantCore, @trace, @raise, within_compile, MissingTracedValue, materialize_traced_array using LinearAlgebra: LinearAlgebra using Random: Random, AbstractRNG