Skip to content

Commit c7c7bde

Browse files
committed
Ensure closures capture the types of typed variables
1 parent 4c3494f commit c7c7bde

File tree

7 files changed

+179
-93
lines changed

7 files changed

+179
-93
lines changed

src/closure_conversion.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,29 +90,26 @@ end
9090
# global and for converting the return value of a function call to the declared
9191
# return type.
9292
function convert_for_type_decl(ctx, srcref, ex, type, do_typeassert)
93-
# Require that the caller make `type` "simple", for now (can generalize
94-
# later if necessary)
95-
kt = kind(type)
96-
@assert (kt == K"Identifier" || kt == K"BindingId" || is_literal(kt))
9793
# Use a slot to permit union-splitting this in inference
9894
tmp = new_local_binding(ctx, srcref, "tmp", is_always_defined=true)
9995

10096
@ast ctx srcref [K"block"
97+
type_tmp := type
10198
# [K"=" type_ssa renumber_assigned_ssavalues(type)]
10299
[K"=" tmp ex]
103100
[K"if"
104-
[K"call" "isa"::K"core" tmp type]
101+
[K"call" "isa"::K"core" tmp type_tmp]
105102
"nothing"::K"core"
106103
[K"="
107104
tmp
108105
if do_typeassert
109106
[K"call"
110107
"typeassert"::K"core"
111-
[K"call" "convert"::K"top" type tmp]
112-
type
108+
[K"call" "convert"::K"top" type_tmp tmp]
109+
type_tmp
113110
]
114111
else
115-
[K"call" "convert"::K"top" type tmp]
112+
[K"call" "convert"::K"top" type_tmp tmp]
116113
end
117114
]
118115
]

src/scope_analysis.jl

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,14 @@ function _resolve_scopes(ctx, ex::SyntaxTree)
433433
throw(LoweringError(ex, "type declarations for global variables must be at top level, not inside a function"))
434434
end
435435
end
436+
id = ex_out[1]
437+
if kind(id) != K"Placeholder"
438+
binfo = lookup_binding(ctx, id)
439+
if !isnothing(binfo.type)
440+
throw(LoweringError(ex, "multiple type declarations found for `$(binfo.name)`"))
441+
end
442+
update_binding!(ctx, id; type=ex_out[2])
443+
end
436444
ex_out
437445
elseif k == K"always_defined"
438446
id = lookup_var(ctx, NameKey(ex[1]))
@@ -624,14 +632,22 @@ function analyze_variables!(ctx, ex)
624632
k = kind(ex)
625633
if k == K"BindingId"
626634
if has_lambda_binding(ctx, ex)
627-
# FIXME: Move this after closure conversion so that we don't need
635+
# TODO: Move this after closure conversion so that we don't need
628636
# to model the closure conversion transformations here.
629637
update_lambda_binding!(ctx, ex, is_read=true)
638+
else
639+
binfo = lookup_binding(ctx, ex.var_id)
640+
if !binfo.is_ssa && binfo.kind != :global
641+
# The type of typed locals is invisible in the previous pass,
642+
# but is filled in here.
643+
init_lambda_binding(ctx.lambda_bindings, ex.var_id, is_captured=true, is_read=true)
644+
update_binding!(ctx, ex, is_captured=true)
645+
end
630646
end
631647
elseif is_leaf(ex) || is_quoted(ex)
632648
return
633649
elseif k == K"local" || k == K"global"
634-
# Uses of bindings which don't count as uses.
650+
# Presence of BindingId within local/global is ignored.
635651
return
636652
elseif k == K"="
637653
lhs = ex[1]
@@ -640,6 +656,12 @@ function analyze_variables!(ctx, ex)
640656
if has_lambda_binding(ctx, lhs)
641657
update_lambda_binding!(ctx, lhs, is_assigned=true)
642658
end
659+
lhs_binfo = lookup_binding(ctx, lhs)
660+
if !isnothing(lhs_binfo.type)
661+
# Assignments introduce a variable's type later during closure
662+
# conversion, but we must model that explicitly here.
663+
analyze_variables!(ctx, lhs_binfo.type)
664+
end
643665
end
644666
analyze_variables!(ctx, ex[2])
645667
elseif k == K"function_decl"
@@ -655,17 +677,6 @@ function analyze_variables!(ctx, ex)
655677
if kind(ex[1]) != K"BindingId" || lookup_binding(ctx, ex[1]).kind !== :local
656678
analyze_variables!(ctx, ex[1])
657679
end
658-
elseif k == K"decl"
659-
@chk numchildren(ex) == 2
660-
id = ex[1]
661-
if kind(id) != K"Placeholder"
662-
binfo = lookup_binding(ctx, id)
663-
if !isnothing(binfo.type)
664-
throw(LoweringError(ex, "multiple type declarations found for `$(binfo.name)`"))
665-
end
666-
update_binding!(ctx, id; type=ex[2])
667-
end
668-
analyze_variables!(ctx, ex[2])
669680
elseif k == K"const"
670681
id = ex[1]
671682
if lookup_binding(ctx, id).kind == :local
@@ -677,7 +688,7 @@ function analyze_variables!(ctx, ex)
677688
if kind(name) == K"BindingId"
678689
id = name.var_id
679690
if has_lambda_binding(ctx, id)
680-
# FIXME: Move this after closure conversion so that we don't need
691+
# TODO: Move this after closure conversion so that we don't need
681692
# to model the closure conversion transformations.
682693
update_lambda_binding!(ctx, id, is_called=true)
683694
end
@@ -710,9 +721,10 @@ function analyze_variables!(ctx, ex)
710721
end
711722
ctx2 = VariableAnalysisContext(ctx.graph, ctx.bindings, ctx.mod, lambda_bindings,
712723
ctx.method_def_stack, ctx.closure_bindings)
713-
# Add any captured bindings to the enclosing lambda, if necessary.
724+
foreach(e->analyze_variables!(ctx2, e), ex[3:end]) # body & return type
714725
for (id,lbinfo) in pairs(lambda_bindings.bindings)
715726
if lbinfo.is_captured
727+
# Add any captured bindings to the enclosing lambda, if necessary.
716728
outer_lbinfo = lookup_lambda_binding(ctx.lambda_bindings, id)
717729
if isnothing(outer_lbinfo)
718730
# Inner lambda captures a variable. If it's not yet present
@@ -723,9 +735,6 @@ function analyze_variables!(ctx, ex)
723735
end
724736
end
725737
end
726-
727-
# TODO: Types of any assigned captured vars will also be used and might be captured.
728-
foreach(e->analyze_variables!(ctx2, e), ex[3:end])
729738
else
730739
foreach(e->analyze_variables!(ctx, e), children(ex))
731740
end

test/assignments_ir.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,19 @@ end
8181
1 (newvar slot₁/x)
8282
2 TestMod.f
8383
3 (call %₂)
84-
4 (= slot₂/tmp %₃)
85-
5 slot₂/tmp
86-
6 TestMod.T
87-
7 (call core.isa %%)
84+
4 TestMod.T
85+
5 (= slot₂/tmp %₃)
86+
6 slot₂/tmp
87+
7 (call core.isa %%)
8888
8 (gotoifnot %₇ label₁₀)
89-
9 (goto label₁₅)
90-
10 TestMod.T
91-
11 slot₂/tmp
92-
12 (call top.convert %₁₀ %₁₁)
93-
13 TestMod.T
94-
14 (= slot₂/tmp (call core.typeassert %₁₂ %₁₃))
95-
15 slot₂/tmp
96-
16 (= slot₁/x %₁₅)
97-
17 slot₁/x
98-
18 (return %₁₇)
89+
9 (goto label₁₃)
90+
10 slot₂/tmp
91+
11 (call top.convert %%₁₀)
92+
12 (= slot₂/tmp (call core.typeassert %₁₁ %₄))
93+
13 slot₂/tmp
94+
14 (= slot₁/x %₁₃)
95+
15 slot₁/x
96+
16 (return %₁₅)
9997

10098
########################################
10199
# "complex lhs" of `::T` => type-assert, not decl

test/closures.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,62 @@ begin
126126
end
127127
""") === (1,2,3)
128128

129+
# Closure with return type must capture the return type
130+
@test JuliaLowering.include_string(test_mod, """
131+
let T = Int
132+
function f_captured_return_type()::T
133+
2.0
134+
end
135+
f_captured_return_type()
136+
end
137+
""") === 2
138+
139+
# Capturing a typed local
140+
@test JuliaLowering.include_string(test_mod, """
141+
let T = Int
142+
x::T = 1.0
143+
function f_captured_typed_local()
144+
x = 2.0
145+
end
146+
f_captured_typed_local()
147+
x
148+
end
149+
""") === 2
150+
151+
# Capturing a typed local where the type is a nontrivial expression
152+
@test begin
153+
res = JuliaLowering.include_string(test_mod, """
154+
let T = Int, V=Vector
155+
x::V{T} = [1,2]
156+
function f_captured_typed_local_composite()
157+
x = [100.0, 200.0]
158+
end
159+
f_captured_typed_local_composite()
160+
x
161+
end
162+
""")
163+
res == [100, 200] && eltype(res) == Int
164+
end
165+
166+
# Evil case where we mutate `T` which is the type of `x`, such that x is
167+
# eventually set to a Float64.
168+
#
169+
# Completely dynamic types for variables should be disallowed somehow?? For
170+
# example, by emitting the expression computing the type of `x` alongside the
171+
# newvar node. However, for now we verify that this potentially evil behavior
172+
# is compatible with the existing implementation :)
173+
@test JuliaLowering.include_string(test_mod, """
174+
let T = Int
175+
x::T = 1.0
176+
function f_captured_mutating_typed_local()
177+
x = 2
178+
end
179+
T = Float64
180+
f_captured_mutating_typed_local()
181+
x
182+
end
183+
""") === 2.0
184+
129185
# Anon function syntax
130186
@test JuliaLowering.include_string(test_mod, """
131187
begin

test/closures_ir.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,40 @@ end
676676
46 slot₃/f_kw_closure
677677
47 (return %₄₆)
678678

679+
########################################
680+
# Closure capturing a typed local must also capture the type expression
681+
# [method_filter: #f_captured_typed_local##0]
682+
let T=Blah
683+
x::T = 1.0
684+
function f_captured_typed_local()
685+
x = 2.0
686+
end
687+
f_captured_typed_local()
688+
x
689+
end
690+
#---------------------
691+
slots: [slot₁/#self#(!read) slot₂/T(!read) slot₃/tmp(!read)]
692+
1 2.0
693+
2 (call core.getfield slot₁/#self# :x)
694+
3 (call core.getfield slot₁/#self# :T)
695+
4 (call core.isdefined %:contents)
696+
5 (gotoifnot %₄ label₇)
697+
6 (goto label₉)
698+
7 (newvar slot₂/T)
699+
8 slot₂/T
700+
9 (call core.getfield %:contents)
701+
10 (= slot₃/tmp %₁)
702+
11 slot₃/tmp
703+
12 (call core.isa %₁₁ %₉)
704+
13 (gotoifnot %₁₂ label₁₅)
705+
14 (goto label₁₈)
706+
15 slot₃/tmp
707+
16 (call top.convert %%₁₅)
708+
17 (= slot₃/tmp (call core.typeassert %₁₆ %₉))
709+
18 slot₃/tmp
710+
19 (call core.setfield! %:contents %₁₈)
711+
20 (return %₁)
712+
679713
########################################
680714
# Error: Closure outside any top level context
681715
# (Should only happen in a user-visible way when lowering code emitted

test/decls_ir.jl

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,18 @@ local x::T = 1
44
#---------------------
55
1 (newvar slot₁/x)
66
2 1
7-
3 (= slot₂/tmp %₂)
8-
4 slot₂/tmp
9-
5 TestMod.T
10-
6 (call core.isa %%)
7+
3 TestMod.T
8+
4 (= slot₂/tmp %₂)
9+
5 slot₂/tmp
10+
6 (call core.isa %%)
1111
7 (gotoifnot %₆ label₉)
12-
8 (goto label₁₄)
13-
9 TestMod.T
14-
10 slot₂/tmp
15-
11 (call top.convert %%₁₀)
16-
12 TestMod.T
17-
13 (= slot₂/tmp (call core.typeassert %₁₁ %₁₂))
18-
14 slot₂/tmp
19-
15 (= slot₁/x %₁₄)
20-
16 (return %₂)
12+
8 (goto label₁₂)
13+
9 slot₂/tmp
14+
10 (call top.convert %%₉)
15+
11 (= slot₂/tmp (call core.typeassert %₁₀ %₃))
16+
12 slot₂/tmp
17+
13 (= slot₁/x %₁₂)
18+
14 (return %₂)
2119

2220
########################################
2321
# const
@@ -133,35 +131,31 @@ end
133131
8 --- method core.nothing %
134132
slots: [slot₁/#self#(!read) slot₂/x slot₃/tmp(!read) slot₄/tmp(!read)]
135133
1 1
136-
2 (= slot₃/tmp %₁)
137-
3 slot₃/tmp
138-
4 TestMod.Int
139-
5 (call core.isa %%)
134+
2 TestMod.Int
135+
3 (= slot₃/tmp %₁)
136+
4 slot₃/tmp
137+
5 (call core.isa %%)
140138
6 (gotoifnot %₅ label₈)
141-
7 (goto label₁)
142-
8 TestMod.Int
143-
9 slot₃/tmp
144-
10 (call top.convert %%)
145-
11 TestMod.Int
146-
12 (= slot/tmp (call core.typeassert %₁₀ %₁₁))
147-
13 slot₃/tmp
148-
14 (= slot₂/x %₁₃)
149-
15 2.0
150-
16 (= slot₄/tmp %₁₅)
151-
17 slot₄/tmp
152-
18 TestMod.Int
153-
19 (call core.isa %₁₇ %₁₈)
154-
20 (gotoifnot %₁₉ label₂₂)
155-
21 (goto label₂₇)
156-
22 TestMod.Int
139+
7 (goto label₁)
140+
8 slot₃/tmp
141+
9 (call top.convert %%₈)
142+
10 (= slot₃/tmp (call core.typeassert %%₂))
143+
11 slot₃/tmp
144+
12 (= slot/x %₁₁)
145+
13 2.0
146+
14 TestMod.Int
147+
15 (= slot₄/tmp %₁₃)
148+
16 slot₄/tmp
149+
17 (call core.isa %₁₆ %₁₄)
150+
18 (gotoifnot %₁₇ label₂₀)
151+
19 (goto label₂₃)
152+
20 slot₄/tmp
153+
21 (call top.convert %₁₄ %₂₀)
154+
22 (= slot₄/tmp (call core.typeassert %₂₁ %₁₄))
157155
23 slot₄/tmp
158-
24 (call top.convert %₂₂ %₂₃)
159-
25 TestMod.Int
160-
26 (= slot₄/tmp (call core.typeassert %₂₄ %₂₅))
161-
27 slot₄/tmp
162-
28 (= slot₂/x %₂₇)
163-
29 slot₂/x
164-
30 (return %₂₉)
156+
24 (= slot₂/x %₂₃)
157+
25 slot₂/x
158+
26 (return %₂₅)
165159
9 TestMod.f
166160
10 (return %₉)
167161

test/destructuring_ir.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -341,20 +341,18 @@ end
341341
1 (newvar slot₁/x)
342342
2 TestMod.rhs
343343
3 (call top.getproperty %:x)
344-
4 (= slot₂/tmp %₃)
345-
5 slot₂/tmp
346-
6 TestMod.T
347-
7 (call core.isa %%)
344+
4 TestMod.T
345+
5 (= slot₂/tmp %₃)
346+
6 slot₂/tmp
347+
7 (call core.isa %%)
348348
8 (gotoifnot %₇ label₁₀)
349-
9 (goto label₁₅)
350-
10 TestMod.T
351-
11 slot₂/tmp
352-
12 (call top.convert %₁₀ %₁₁)
353-
13 TestMod.T
354-
14 (= slot₂/tmp (call core.typeassert %₁₂ %₁₃))
355-
15 slot₂/tmp
356-
16 (= slot₁/x %₁₅)
357-
17 (return %₂)
349+
9 (goto label₁₃)
350+
10 slot₂/tmp
351+
11 (call top.convert %%₁₀)
352+
12 (= slot₂/tmp (call core.typeassert %₁₁ %₄))
353+
13 slot₂/tmp
354+
14 (= slot₁/x %₁₃)
355+
15 (return %₂)
358356

359357
########################################
360358
# Error: Property destructuring with frankentuple

0 commit comments

Comments
 (0)