Skip to content

Start raise macro #1492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

export @trace, within_compile, MissingTracedValue, promote_to_traced
export @trace, @raise, within_compile, MissingTracedValue, promote_to_traced

# Traits
function is_traced((@nospecialize x::T), seen=Base.IdSet()) where {T}
Expand Down Expand Up @@ -637,4 +637,62 @@ Given an AbstractArray{TracedRNumber}, return or create an equivalent TracedRArr
"""
function materialize_traced_array end

"""
@raise expr

Signals to raise the code block `expr` from LLVM during compilation.
"""
macro raise(args...)
expr = first(args)
if length(args) > 1
error("@raise only supports a single block of code, got $(length(args))")
end
esc(_raise(expr))
end

function _raise(expr)
ee = ExpressionExplorer.compute_symbols_state(expr)
refs = ee.references
assi = ee.assignments

# Prevent boxing when esc to we generate new names
arg_syms = [gensym(s) for s in refs]
ass_syms = [gensym(s) for s in refs]


expr_with_args = MacroTools.postwalk(expr) do x
for (i, r) in enumerate(refs)
if x == r
return arg_syms[i]
end
end
for (i, r) in enumerate(assi)
if x == r
return ass_syms[i]
end
end

return x
end

func = gensym(:raise_fn)
reactant_code_block = quote
function $func($(arg_syms...))
$(expr_with_args)
end
($func, ($(refs...),))
end
cond_val(s) = :(@isdefined($s) ? $s : nothing)

return quote
if $(within_compile)() &&
$(any)($(is_traced), ($(cond_val.(refs)...),))
$(reactant_code_block)
else
$(expr)
end
end
end


end
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Reactant

using ReactantCore:
ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array
ReactantCore, @trace, @raise, within_compile, MissingTracedValue, materialize_traced_array

using LinearAlgebra: LinearAlgebra
using Random: Random, AbstractRNG
Expand Down