Skip to content

Commit 10667c5

Browse files
committed
Fix for toplevel-preserving statements in closure conversion
Previously, closures were lifted to the outermost level of the toplevel thunk. Instead they should be kept inside any `if`, `try` and `block` top level statements, but lifted out of most other constructs.
1 parent 7c05bd1 commit 10667c5

File tree

6 files changed

+325
-262
lines changed

6 files changed

+325
-262
lines changed

src/closure_conversion.jl

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ struct ClosureConversionCtx{GraphType} <: AbstractLoweringContext
1515
closure_bindings::Dict{IdTag,ClosureBindings}
1616
capture_rewriting::Union{Nothing,ClosureInfo{GraphType},SyntaxList{GraphType}}
1717
lambda_bindings::LambdaBindings
18+
# True if we're in a section of code which preserves top-level sequencing
19+
# such that closure types can be emitted inline with other code.
20+
is_toplevel_seq_point::Bool
1821
toplevel_stmts::SyntaxList{GraphType}
1922
closure_infos::Dict{IdTag,ClosureInfo{GraphType}}
2023
end
@@ -23,8 +26,8 @@ function ClosureConversionCtx(graph::GraphType, bindings::Bindings,
2326
mod::Module, closure_bindings::Dict{IdTag,ClosureBindings},
2427
lambda_bindings::LambdaBindings) where {GraphType}
2528
ClosureConversionCtx{GraphType}(
26-
graph, bindings, mod, closure_bindings, nothing, lambda_bindings, SyntaxList(graph),
27-
Dict{IdTag,ClosureInfo{GraphType}}())
29+
graph, bindings, mod, closure_bindings, nothing,
30+
lambda_bindings, false, SyntaxList(graph), Dict{IdTag,ClosureInfo{GraphType}}())
2831
end
2932

3033
function current_lambda_bindings(ctx::ClosureConversionCtx)
@@ -288,6 +291,28 @@ function is_self_captured(ctx, x)
288291
!isnothing(lbinfo) && lbinfo.is_captured
289292
end
290293

294+
# Map the children of `ex` through _convert_closures, lifting any toplevel
295+
# closure definition statements to occur before the other content of `ex`.
296+
function map_cl_convert(ctx::ClosureConversionCtx, ex, toplevel_preserving)
297+
if ctx.is_toplevel_seq_point && !toplevel_preserving
298+
toplevel_stmts = SyntaxList(ctx)
299+
ctx2 = ClosureConversionCtx(ctx.graph, ctx.bindings, ctx.mod,
300+
ctx.closure_bindings, ctx.capture_rewriting, ctx.lambda_bindings,
301+
false, toplevel_stmts, ctx.closure_infos)
302+
res = mapchildren(e->_convert_closures(ctx2, e), ctx2, ex)
303+
if isempty(toplevel_stmts)
304+
res
305+
else
306+
@ast ctx ex [K"block"
307+
toplevel_stmts...
308+
res
309+
]
310+
end
311+
else
312+
mapchildren(e->_convert_closures(ctx, e), ctx, ex)
313+
end
314+
end
315+
291316
function _convert_closures(ctx::ClosureConversionCtx, ex)
292317
k = kind(ex)
293318
if k == K"BindingId"
@@ -358,7 +383,10 @@ function _convert_closures(ctx::ClosureConversionCtx, ex)
358383
"#$(join(closure_binds.name_stack, "#"))##")
359384
closure_type_def, closure_type_ =
360385
type_for_closure(ctx, ex, name_str, field_syms, field_is_box)
361-
push!(ctx.toplevel_stmts, closure_type_def)
386+
if !ctx.is_toplevel_seq_point
387+
push!(ctx.toplevel_stmts, closure_type_def)
388+
closure_type_def = nothing
389+
end
362390
closure_info = ClosureInfo(closure_type_, field_syms, field_inds)
363391
ctx.closure_infos[func_name_id] = closure_info
364392
type_params = SyntaxList(ctx)
@@ -375,6 +403,7 @@ function _convert_closures(ctx::ClosureConversionCtx, ex)
375403
end
376404
end
377405
@ast ctx ex [K"block"
406+
closure_type_def
378407
closure_type := if isempty(type_params)
379408
closure_type_
380409
else
@@ -395,7 +424,7 @@ function _convert_closures(ctx::ClosureConversionCtx, ex)
395424
# binding for `func_name` if it doesn't exist.
396425
@ast ctx ex [K"block"
397426
[K"method" func_name]
398-
::K"TOMBSTONE"
427+
::K"TOMBSTONE" # <- function_decl should not be used in value position
399428
]
400429
end
401430
elseif k == K"function_type"
@@ -410,17 +439,17 @@ function _convert_closures(ctx::ClosureConversionCtx, ex)
410439
is_closure = kind(name) == K"BindingId" && lookup_binding(ctx, name).kind === :local
411440
cap_rewrite = is_closure ? ctx.closure_infos[name.var_id] : nothing
412441
ctx2 = ClosureConversionCtx(ctx.graph, ctx.bindings, ctx.mod,
413-
ctx.closure_bindings, cap_rewrite, ctx.lambda_bindings,
414-
ctx.toplevel_stmts, ctx.closure_infos)
415-
body = _convert_closures(ctx2, ex[2])
442+
ctx.closure_bindings, cap_rewrite, ctx.lambda_bindings,
443+
ctx.is_toplevel_seq_point, ctx.toplevel_stmts, ctx.closure_infos)
444+
body = map_cl_convert(ctx2, ex[2], false)
416445
if is_closure
417-
# Move methods to top level
418-
# FIXME: Probably lots more work to do to make this correct
419-
# Especially
420-
# * Renumbering SSA vars
421-
# * Ensuring that moved locals become slots in the top level thunk
422-
push!(ctx.toplevel_stmts, body)
423-
@ast ctx ex (::K"TOMBSTONE")
446+
if ctx.is_toplevel_seq_point
447+
body
448+
else
449+
# Move methods out to a top-level sequence point.
450+
push!(ctx.toplevel_stmts, body)
451+
@ast ctx ex (::K"TOMBSTONE")
452+
end
424453
else
425454
@ast ctx ex [K"block"
426455
body
@@ -435,8 +464,8 @@ function _convert_closures(ctx::ClosureConversionCtx, ex)
435464
capture_rewrites = ClosureInfo(ex #=unused=#, field_syms, field_inds)
436465

437466
ctx2 = ClosureConversionCtx(ctx.graph, ctx.bindings, ctx.mod,
438-
ctx.closure_bindings, capture_rewrites, ctx.lambda_bindings,
439-
ctx.toplevel_stmts, ctx.closure_infos)
467+
ctx.closure_bindings, capture_rewrites, ctx.lambda_bindings,
468+
false, ctx.toplevel_stmts, ctx.closure_infos)
440469

441470
init_closure_args = SyntaxList(ctx)
442471
for id in field_orig_bindings
@@ -457,31 +486,36 @@ function _convert_closures(ctx::ClosureConversionCtx, ex)
457486
init_closure_args...
458487
]
459488
else
460-
mapchildren(e->_convert_closures(ctx, e), ctx, ex)
489+
# A small number of kinds are toplevel-preserving in terms of closure
490+
# closure definitions will be lifted out into `toplevel_stmts` if they
491+
# occur inside `ex`.
492+
toplevel_seq_preserving = k == K"if" || k == K"elseif" || k == K"block" ||
493+
k == K"tryfinally" || k == K"trycatchelse"
494+
map_cl_convert(ctx, ex, toplevel_seq_preserving)
461495
end
462496
end
463497

464498
function closure_convert_lambda(ctx, ex)
465499
@assert kind(ex) == K"lambda"
466-
body_stmts = SyntaxList(ctx)
467-
toplevel_stmts = ex.is_toplevel_thunk ? body_stmts : ctx.toplevel_stmts
468500
lambda_bindings = ex.lambda_bindings
469501
interpolations = nothing
470502
if isnothing(ctx.capture_rewriting)
503+
# Global method which may capture locals
471504
interpolations = SyntaxList(ctx)
472505
cap_rewrite = interpolations
473506
else
474507
cap_rewrite = ctx.capture_rewriting
475508
end
476509
ctx2 = ClosureConversionCtx(ctx.graph, ctx.bindings, ctx.mod,
477510
ctx.closure_bindings, cap_rewrite, lambda_bindings,
478-
toplevel_stmts, ctx.closure_infos)
511+
ex.is_toplevel_thunk, ctx.toplevel_stmts, ctx.closure_infos)
479512
lambda_children = SyntaxList(ctx)
480513
args = ex[1]
481514
push!(lambda_children, args)
482515
push!(lambda_children, ex[2])
483516

484517
# Add box initializations for arguments which are captured by an inner lambda
518+
body_stmts = SyntaxList(ctx)
485519
for arg in children(args)
486520
kind(arg) != K"Placeholder" || continue
487521
if is_boxed(ctx, arg)
@@ -491,8 +525,7 @@ function closure_convert_lambda(ctx, ex)
491525
])
492526
end
493527
end
494-
# Convert body. Note that _convert_closures may call `push!(body_stmts, e)`
495-
# internally for any expressions `e` which need to be moved to top level.
528+
# Convert body.
496529
input_body_stmts = kind(ex[3]) != K"block" ? ex[3:3] : ex[3][1:end]
497530
for e in input_body_stmts
498531
push!(body_stmts, _convert_closures(ctx2, e))
@@ -538,5 +571,8 @@ function convert_closures(ctx::VariableAnalysisContext, ex)
538571
ctx = ClosureConversionCtx(ctx.graph, ctx.bindings, ctx.mod,
539572
ctx.closure_bindings, ex.lambda_bindings)
540573
ex1 = closure_convert_lambda(ctx, ex)
574+
if !isempty(ctx.toplevel_stmts)
575+
throw(LoweringError(first(ctx.toplevel_stmts), "Top level code was found outside any top level context. `@generated` functions may not contain closures, including `do` syntax and generators/comprehension"))
576+
end
541577
ctx, ex1
542578
end

test/closures.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
2-
@testset "Functions" begin
1+
@testset "Closures" begin
32

43
test_mod = Module()
54

@@ -96,6 +95,21 @@ end
9695
@test test_mod.f_global_method_capturing_local() == 2
9796
@test test_mod.f_global_method_capturing_local() == 3
9897

98+
# Closure with multiple methods depending on local variables
99+
f_closure_local_var_types = JuliaLowering.include_string(test_mod, """
100+
let T=Int, S=Float64
101+
function f_closure_local_var_types(::T)
102+
1
103+
end
104+
function f_closure_local_var_types(::S)
105+
1.0
106+
end
107+
end
108+
""")
109+
@test f_closure_local_var_types(2) == 1
110+
@test f_closure_local_var_types(2.0) == 1.0
111+
@test_throws MethodError f_closure_local_var_types("hi")
112+
99113
# Anon function syntax
100114
@test JuliaLowering.include_string(test_mod, """
101115
begin

0 commit comments

Comments
 (0)