Skip to content

Commit bd6b7bb

Browse files
committed
Desugaring and runtime support for generated functions
I've opted to diverge from Base's lowering of `@ generated` code generators to add a semi-hidden `__context__` argument to the lowered code generator function in exact analogy to the context added to macros. This makes the code generator even more macro-like, and I've also reused the MacroExpansionContext. We need our own `GeneratedFunctionStub` here in order to construct the context, to call back into JuliaLowering's version of the lowering machinery in a precise way, and to propagate enhanced provenance information. This all turned out to be quite subtle, but JuliaLowering is able to almost fully integrate with the runtime without any changes to the runtime itself. Thus we can see the effort people have put into building abstractions for Cassette/etc has really paid off.
1 parent 3c39234 commit bd6b7bb

File tree

11 files changed

+485
-35
lines changed

11 files changed

+485
-35
lines changed

README.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,69 @@ In total, this expands a single "function definition" into seven methods.
398398
Note that the above is only a sketch! There's more fiddly details when `where`
399399
syntax comes in
400400

401+
### Desugaring of generated functions
402+
403+
A brief description of how this works. Let's consider the generated function
404+
405+
```julia
406+
function gen(x::NTuple{N}, y) where {N,T}
407+
shared = :shared
408+
# Unnecessary use of @generated, but it shows what's going on.
409+
if @generated
410+
quote
411+
maybe_gen = ($x, $N)
412+
end
413+
else
414+
maybe_gen = (typeof(x), N)
415+
end
416+
(shared, maybe_gen)
417+
end
418+
```
419+
420+
This is desugared into the following two function definitions. First, a code
421+
generator which will generate code for the body of the function, given the
422+
static parameters `N`, `T` and the positional arguments `x`, `y`.
423+
(`var"#self#"::Type{typeof(gen)}` is also provided by the Julia runtime to
424+
complete the full signature of `gen`, though the user won't normally use this.)
425+
426+
```julia
427+
function var"#gen@generator#0"(__context__::JuliaLowering.MacroContext, N, T, var"#self#", x, y)
428+
gen_stuff = quote
429+
maybe_gen = ($x, $N)
430+
end
431+
quote
432+
shared = :shared
433+
$gen_stuff
434+
(shared, maybe_gen)
435+
end
436+
end
437+
```
438+
439+
Second, the non-generated version, using the `else` branches of `if @generated`, and
440+
containing mostly normal code.
441+
442+
```julia
443+
function gen(x::NTuple{N}, y) where {N,T}
444+
$(Expr(:meta, :generated,
445+
Expr(:call, JuliaLowering.GeneratedFunctionStub,
446+
:var"#gen@generator#0", sourceref_of_gen,
447+
:(Core.svec(:var"#self#", :x, :y)),
448+
:(Core.svec(:N, :T)))))
449+
shared = :shared
450+
maybe_gen = (typeof(x), N)
451+
(shared, maybe_gen)
452+
end
453+
```
454+
455+
The one extra thing added here is the `Expr(:meta, :generated, ...)` whose last argument is an
456+
expression creating a callable wrapper for the user's generator, to be
457+
evaluated at top level. This wrapper will then be invoked by the runtime
458+
whenever the user calls `gen` with a new signature and it's expected that a
459+
`CodeInfo` be returned from it. `JuliaLowering.GeneratedFunctionStub` differs
460+
from `Core.GeneratedFunctionStub` in that it contains extra provenance
461+
information (the `sourceref_of_gen`) and expects a `SyntaxTree` to be returned
462+
by the user's generator code.
463+
401464
## Pass 3: Scope analysis / binding resolution
402465

403466
This pass replaces variables with bindings of kind `K"BindingId"`,

src/desugaring.jl

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,7 +2213,7 @@ function method_def_expr(ctx, srcref, callex_srcref, method_table,
22132213
::K"SourceLocation"(callex_srcref)
22142214
]
22152215
[K"method"
2216-
method_table
2216+
isnothing(method_table) ? "nothing"::K"core" : method_table
22172217
method_metadata
22182218
[K"lambda"(body, is_toplevel_thunk=false)
22192219
[K"block" arg_names...]
@@ -2285,6 +2285,117 @@ function trim_used_typevars(ctx, arg_types, typevar_names, typevar_stmts)
22852285
return trimmed_typevar_names
22862286
end
22872287

2288+
function is_if_generated(ex)
2289+
kind(ex) == K"if" && kind(ex[1]) == K"generated"
2290+
end
2291+
2292+
# Return true if a function body contains a code generator from `@generated` in
2293+
# the form `[K"if" [K"generated"] ...]`
2294+
function is_generated(ex)
2295+
if is_if_generated(ex)
2296+
return true
2297+
elseif is_quoted(ex) || kind(ex) == K"function"
2298+
return false
2299+
else
2300+
return any(is_generated, children(ex))
2301+
end
2302+
end
2303+
2304+
function split_generated(ctx, ex, gen_part)
2305+
if is_leaf(ex)
2306+
ex
2307+
elseif is_if_generated(ex)
2308+
gen_part ? @ast(ctx, ex, [K"$" ex[2]]) : ex[3]
2309+
else
2310+
mapchildren(e->split_generated(ctx, e, gen_part), ctx, ex)
2311+
end
2312+
end
2313+
2314+
# Split @generated function body into two parts:
2315+
# * The code generator
2316+
# * The non-generated function body
2317+
function expand_function_generator(ctx, srcref, callex_srcref, func_name, func_name_str, body, arg_names, typevar_names)
2318+
gen_body = if is_if_generated(body)
2319+
body[2] # Simple case - don't need interpolation when the whole body is generated
2320+
else
2321+
expand_quote(ctx, @ast ctx body [K"block" split_generated(ctx, body, true)])
2322+
end
2323+
gen_name_str = reserve_module_binding_i(ctx.mod,
2324+
"#$(isnothing(func_name_str) ? "_" : func_name_str)@generator#")
2325+
gen_name = new_global_binding(ctx, body, gen_name_str, ctx.mod)
2326+
2327+
# Set up the arguments for the code generator
2328+
gen_arg_names = SyntaxList(ctx)
2329+
gen_arg_types = SyntaxList(ctx)
2330+
# Self arg
2331+
push!(gen_arg_names, new_local_binding(ctx, callex_srcref, "#self#"; kind=:argument))
2332+
push!(gen_arg_types, @ast ctx callex_srcref [K"function_type" gen_name])
2333+
# Macro expansion context arg
2334+
if kind(func_name) != K"Identifier"
2335+
TODO(func_name, "Which scope do we adopt for @generated generator `__context__` in this case?")
2336+
end
2337+
push!(gen_arg_names, adopt_scope(@ast(ctx, callex_srcref, "__context__"::K"Identifier"), func_name))
2338+
push!(gen_arg_types, @ast(ctx, callex_srcref, MacroContext::K"Value"))
2339+
# Trailing arguments to the generator are provided by the Julia runtime. They are:
2340+
# static_parameters... parent_function arg_types...
2341+
first_trailing_arg = length(gen_arg_names) + 1
2342+
append!(gen_arg_names, typevar_names)
2343+
append!(gen_arg_names, arg_names)
2344+
# Apply nospecialize to all arguments to prevent so much codegen and add
2345+
# Core.Any type for them
2346+
for i in first_trailing_arg:length(gen_arg_names)
2347+
gen_arg_names[i] = setmeta(gen_arg_names[i]; nospecialize=true)
2348+
push!(gen_arg_types, @ast ctx gen_arg_names[i] "Any"::K"core")
2349+
end
2350+
# Code generator definition
2351+
gen_func_method_defs = @ast ctx srcref [K"method_defs"
2352+
gen_name
2353+
method_def_expr(ctx, srcref, callex_srcref, nothing, SyntaxList(ctx), gen_arg_names,
2354+
gen_arg_types, gen_body, nothing)
2355+
]
2356+
2357+
# Extract non-generated body
2358+
nongen_body = @ast ctx body [K"block"
2359+
# The Julia runtime associates the code generator with the
2360+
# non-generated method by adding this meta to the body. This feels like
2361+
# a hack though since the generator ultimately gets attached to the
2362+
# method rather than the CodeInfo which we're putting it inside.
2363+
[K"meta"
2364+
"generated"::K"Symbol"
2365+
# The following is code to be evaluated at top level and will wrap
2366+
# whatever code comes from the user's generator into an appropriate
2367+
# K"lambda" (+ K"with_static_parameters") suitable for lowering
2368+
# into a CodeInfo.
2369+
#
2370+
# todo: As isolated top-level code, we don't actually want to apply
2371+
# the normal scope rules of the surrounding function ... it should
2372+
# technically have scope resolved at top level.
2373+
[K"new"
2374+
GeneratedFunctionStub::K"Value" # Use stub type from JuliaLowering
2375+
gen_name
2376+
# Truncate provenance to just the source file range, as this
2377+
# will live permanently in the IR and we probably don't want
2378+
# the full provenance tree and intermediate expressions
2379+
# (TODO: More truncation. We certainly don't want to store the
2380+
# source file either.)
2381+
sourceref(srcref)::K"Value"
2382+
[K"call"
2383+
"svec"::K"core"
2384+
"#self#"::K"Symbol"
2385+
(n.name_val::K"Symbol"(n) for n in arg_names[2:end])...
2386+
]
2387+
[K"call"
2388+
"svec"::K"core"
2389+
(n.name_val::K"Symbol"(n) for n in typevar_names)...
2390+
]
2391+
]
2392+
]
2393+
split_generated(ctx, body, false)
2394+
]
2395+
2396+
return gen_name, gen_func_method_defs, nongen_body
2397+
end
2398+
22882399
# Generate a method for every number of allowed optional arguments
22892400
# For example for `f(x, y=1, z=2)` we generate two additional methods
22902401
# f(x) = f(x, 1, 2)
@@ -2799,6 +2910,14 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
27992910
]
28002911
end
28012912

2913+
gen_func_name = nothing
2914+
gen_func_method_defs = nothing
2915+
if is_generated(body)
2916+
gen_func_name, gen_func_method_defs, body =
2917+
expand_function_generator(ctx, ex, callex, name, name_str, body, arg_names, typevar_names)
2918+
2919+
end
2920+
28022921
if isnothing(keywords)
28032922
body_func_name, kw_func_method_defs = (nothing, nothing)
28042923
# NB: This check seems good as it statically catches any useless
@@ -2848,6 +2967,9 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
28482967
end
28492968

28502969
@ast ctx ex [K"block"
2970+
if !isnothing(gen_func_name)
2971+
[K"function_decl"(gen_func_name) gen_func_name]
2972+
end
28512973
if !isnothing(body_func_name)
28522974
[K"function_decl"(body_func_name) body_func_name]
28532975
end
@@ -2857,6 +2979,7 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
28572979
[K"scope_block"(scope_type=:hard)
28582980
[K"block"
28592981
new_typevar_stmts...
2982+
gen_func_method_defs
28602983
kw_func_method_defs
28612984
[K"method_defs"
28622985
isnothing(bare_func_name) ? "nothing"::K"core" : bare_func_name
@@ -3651,7 +3774,7 @@ function expand_struct_def(ctx, ex, docs)
36513774
typevar_in_bounds = any(type_params[i+1:end]) do param
36523775
# Check the bounds of subsequent type params
36533776
(_,lb,ub) = analyze_typevar(ctx, param)
3654-
# TODO: flisp lowering tests `lb` here so we also do. But
3777+
# todo: flisp lowering tests `lb` here so we also do. But
36553778
# in practice this doesn't seem to constrain `typevar_name`
36563779
# and the generated constructor doesn't work?
36573780
(!isnothing(ub) && contains_identifier(ub, typevar_name)) ||

src/eval.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,8 @@ function to_lowered_expr(mod, ex, ssa_offset=0)
230230
Core.SSAValue(ex.var_id + ssa_offset)
231231
elseif k == K"return"
232232
Core.ReturnNode(to_lowered_expr(mod, ex[1], ssa_offset))
233-
elseif is_quoted(k)
234-
if k == K"inert"
235-
ex[1]
236-
else
237-
TODO(ex, "Convert SyntaxTree to Expr")
238-
end
233+
elseif k == K"inert"
234+
ex[1]
239235
elseif k == K"code_info"
240236
funcname = ex.is_toplevel_thunk ?
241237
"top-level scope" :
@@ -269,6 +265,11 @@ function to_lowered_expr(mod, ex, ssa_offset=0)
269265
# TODO: put allow_partial back in once we update to the latest julia
270266
splice!(args, 4) # allow_partial
271267
Expr(:new_opaque_closure, args...)
268+
elseif k == K"meta"
269+
args = Any[to_lowered_expr(mod, e, ssa_offset) for e in children(ex)]
270+
# Unpack K"Symbol" QuoteNode as `Expr(:meta)` requires an identifier here.
271+
args[1] = args[1].value
272+
Expr(:meta, args...)
272273
else
273274
# Allowed forms according to https://docs.julialang.org/en/v1/devdocs/ast/
274275
#

src/kinds.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ function _register_kinds()
99
"BEGIN_EXTENSION_KINDS"
1010
# atomic fields or accesses (see `@atomic`)
1111
"atomic"
12+
# Flag for @generated parts of a function
13+
"generated"
1214
# Temporary rooting of identifiers (GC.@preserve)
1315
"gc_preserve_begin"
1416
"gc_preserve_end"
@@ -46,6 +48,8 @@ function _register_kinds()
4648
# Catch-all for additional syntax extensions without the need to
4749
# extend `Kind`. Known extensions include:
4850
# locals, islocal
51+
# The content of an assertion is not considered to be quoted, so
52+
# use K"Symbol" or K"inert" inside where necessary.
4953
"extension"
5054
"END_EXTENSION_KINDS"
5155

src/linear_ir.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -796,11 +796,22 @@ function compile(ctx::LinearIRContext, ex, needs_value, in_tail_pos)
796796
end
797797
elseif k == K"gc_preserve_begin"
798798
makenode(ctx, ex, k, compile_args(ctx, children(ex)))
799-
elseif k == K"gc_preserve_end"
799+
elseif k == K"gc_preserve_end" || k == K"global" || k == K"const"
800800
if needs_value
801-
throw(LoweringError(ex, "misplaced label in value position"))
801+
throw(LoweringError(ex, "misplaced kind $k in value position"))
802802
end
803803
emit(ctx, ex)
804+
nothing
805+
elseif k == K"meta"
806+
emit(ctx, ex)
807+
if needs_value
808+
val = @ast ctx ex "nothing"::K"core"
809+
if in_tail_pos
810+
emit_return(ctx, val)
811+
else
812+
val
813+
end
814+
end
804815
elseif k == K"_while"
805816
end_label = make_label(ctx, ex)
806817
top_label = emit_label(ctx, ex)
@@ -821,12 +832,6 @@ function compile(ctx::LinearIRContext, ex, needs_value, in_tail_pos)
821832
if needs_value
822833
compile(ctx, nothing_(ctx, ex), needs_value, in_tail_pos)
823834
end
824-
elseif k == K"global" || k == K"const"
825-
if needs_value
826-
throw(LoweringError(ex, "misplaced declaration"))
827-
end
828-
emit(ctx, ex)
829-
nothing
830835
elseif k == K"isdefined" || k == K"captured_local" || k == K"throw_undef_if_not" ||
831836
k == K"boundscheck"
832837
if in_tail_pos
@@ -957,7 +962,12 @@ function _renumber(ctx, ssa_rewrites, slot_rewrites, label_table, ex)
957962
end
958963
end
959964
elseif k == K"meta"
960-
TODO(ex, "_renumber $k")
965+
# Somewhat-hack for Expr(:meta, :generated, gen) which has
966+
# weird top-level semantics for `gen`, but we still need to translate
967+
# the binding it contains to a globalref.
968+
mapchildren(ctx, ex) do e
969+
_renumber(ctx, ssa_rewrites, slot_rewrites, label_table, e)
970+
end
961971
elseif is_literal(k) || is_quoted(k)
962972
ex
963973
elseif k == K"label"
@@ -968,8 +978,6 @@ function _renumber(ctx, ssa_rewrites, slot_rewrites, label_table, ex)
968978
mapchildren(ctx, ex) do e
969979
_renumber(ctx, ssa_rewrites, slot_rewrites, label_table, e)
970980
end
971-
# TODO: foreigncall error check:
972-
# "ccall function name and library expression cannot reference local variables"
973981
end
974982
end
975983

src/macro_expansion.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
#--------------------------------------------------
6767
struct MacroContext <: AbstractLoweringContext
6868
graph::SyntaxGraph
69-
macrocall::SyntaxTree
69+
macrocall::Union{SyntaxTree,LineNumberNode,SourceRef}
7070
scope_layer::ScopeLayer
7171
end
7272

0 commit comments

Comments
 (0)