Skip to content

Commit cdf8820

Browse files
committed
bpart: Turn on invalidation for guard->defined transitions
This addresses one of the last remaining TODOs of the binding partition work by performing invalidations when bindings transition from being undefined to being defined. This in particular finally addresses the performance issue that #54733 was intended to address (the issue was closed when we merged the mechanism, but it had so far been turned off). Turning on the invalidations themselves was always easy (a one-line deletion). What is harder is making sure that the additional invalidations don't take extra time. To this end, we add two additional flags, one on Bindings, and one on methods. The flag on bindings tells us whether any method scan has so far found an implicit (not tracked in ->backedges) reference to this binding in any method body. The insight here is that most undefined bindings will not have been referenced previously (because they did not exist), so with a simple one-bit saturating counter of the number of edges that would exist (if we did store them), we can fast-path the invalidation. However, this is not quite sufficient, as people often do things like:
```
foo() = bar()
bar() = ...
...
```
which, without further improvements, would incur an invalidation upon the definition of `bar`. The second insight (and what the flag on `Method` is for) is that we don't actually need to scan the method body until there is something to invalidate (i.e. until some `CodeInstance` has been created for the method). By deferring the scanning until the first time that inference accesses the lowered code (with a flag to only do it once), we can easily avoid invalidation in the above scenario (while still invalidating if `foo()` was called before the definition of `bar`). As a further bonus, this also speeds up bootstrap by about 20% (putting us roughly back to where we used to be before the full bpart change) by skipping unnecessary invalidations even for non-guard transitions.
Finally, this does not yet turn on inference's ability to infer guard partitions as `Union{}`. The reason for this is that such partitions can be replaced by backdated constants without invalidation. However, as soon as we remove the backdated const mechanism, this PR will allow us to turn on that change, further speeding up inference (by cutting off inference on branches known to error due to missing bindings).
1 parent abf4dcf commit cdf8820

File tree

9 files changed

+104
-28
lines changed

9 files changed

+104
-28
lines changed

Compiler/src/utilities.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,25 @@ function retrieve_code_info(mi::MethodInstance, world::UInt)
129129
else
130130
c = copy(src::CodeInfo)
131131
end
132+
if !def.did_scan_source
133+
# This scan must happen:
134+
# 1. After method definition
135+
# 2. Before any code instances that may have relied on information
136+
# from implicit GlobalRefs for this method are added to the cache
137+
# 3. Preferably while the IR is already uncompressed
138+
# 4. As late as possible, as early adding of the backedges may cause
139+
# spurious invalidations.
140+
#
141+
# At the moment we do so here, because
142+
# 1. It's reasonably late
143+
# 2. It has easy access to the uncompressed IR
144+
# 3. We necessarily pass through here before relying on any
145+
# information obtained from implicit GlobalRefs.
146+
#
147+
# However, the exact placement of this scan is not as important as
148+
# long as the above conditions are met.
149+
ccall(:jl_scan_method_source_now, Cvoid, (Any, Any), def, c)
150+
end
132151
end
133152
if c isa CodeInfo
134153
c.parent = mi

base/expr.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,10 @@ function make_atomic(order, ex)
13521352
op = :+
13531353
elseif ex.head === :(-=)
13541354
op = :-
1355+
elseif ex.head === :(|=)
1356+
op = :|
1357+
elseif ex.head === :(&=)
1358+
op = :&
13551359
elseif @isdefined string
13561360
shead = string(ex.head)
13571361
if endswith(shead, '=')

base/invalidation.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -115,28 +115,27 @@ end
115115

116116
function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core.BindingPartition, new_bpart::Union{Core.BindingPartition, Nothing}, new_max_world::UInt)
117117
gr = b.globalref
118-
if !is_some_guard(binding_kind(invalidated_bpart))
119-
# TODO: We may want to invalidate for these anyway, since they have performance implications
118+
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
120119
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
121120
for method in MethodList(mt)
122121
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
123122
end
124123
return true
125124
end
126-
if isdefined(b, :backedges)
127-
for edge in b.backedges
128-
if isa(edge, CodeInstance)
129-
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
130-
elseif isa(edge, Core.Binding)
131-
isdefined(edge, :partitions) || continue
132-
latest_bpart = edge.partitions
133-
latest_bpart.max_world == typemax(UInt) || continue
134-
is_some_imported(binding_kind(latest_bpart)) || continue
135-
partition_restriction(latest_bpart) === b || continue
136-
invalidate_code_for_globalref!(edge, latest_bpart, nothing, new_max_world)
137-
else
138-
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
139-
end
125+
end
126+
if isdefined(b, :backedges)
127+
for edge in b.backedges
128+
if isa(edge, CodeInstance)
129+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
130+
elseif isa(edge, Core.Binding)
131+
isdefined(edge, :partitions) || continue
132+
latest_bpart = edge.partitions
133+
latest_bpart.max_world == typemax(UInt) || continue
134+
is_some_imported(binding_kind(latest_bpart)) || continue
135+
partition_restriction(latest_bpart) === b || continue
136+
invalidate_code_for_globalref!(edge, latest_bpart, nothing, new_max_world)
137+
else
138+
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
140139
end
141140
end
142141
end
@@ -166,7 +165,11 @@ gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod
166165
# N.B.: This needs to match jl_maybe_add_binding_backedge
167166
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
168167
method = isa(edge, Method) ? edge : edge.def.def::Method
169-
gr_needs_backedge_in_module(b.globalref, method.module) || return
168+
methmod = method.module
169+
if !gr_needs_backedge_in_module(b.globalref, methmod)
170+
@atomic :acquire_release b.flags |= BINDING_FLAG_ANY_IMPLICIT_EDGES
171+
return
172+
end
170173
if !isdefined(b, :backedges)
171174
b.backedges = Any[]
172175
end

base/runtime_internals.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ const PARTITION_FLAG_DEPWARN = 0x40
216216
const PARTITION_MASK_KIND = 0x0f
217217
const PARTITION_MASK_FLAG = 0xf0
218218

219+
const BINDING_FLAG_ANY_IMPLICIT_EDGES = 0x8
220+
219221
is_defined_const_binding(kind::UInt8) = (kind == PARTITION_KIND_CONST || kind == PARTITION_KIND_CONST_IMPORT || kind == PARTITION_KIND_BACKDATED_CONST)
220222
is_some_const_binding(kind::UInt8) = (is_defined_const_binding(kind) || kind == PARTITION_KIND_UNDEF_CONST)
221223
is_some_imported(kind::UInt8) = (kind == PARTITION_KIND_IMPLICIT || kind == PARTITION_KIND_EXPLICIT || kind == PARTITION_KIND_IMPORTED)

src/jltypes.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3275,7 +3275,7 @@ void jl_init_types(void) JL_GC_DISABLED
32753275
jl_svec(5, jl_any_type/*jl_globalref_type*/, jl_any_type, jl_binding_partition_type,
32763276
jl_any_type, jl_uint8_type),
32773277
jl_emptysvec, 0, 1, 0);
3278-
const static uint32_t binding_atomicfields[] = { 0x0005 }; // Set fields 2, 3 as atomic
3278+
const static uint32_t binding_atomicfields[] = { 0x0016 }; // Set fields 2, 3, 5 as atomic
32793279
jl_binding_type->name->atomicfields = binding_atomicfields;
32803280
const static uint32_t binding_constfields[] = { 0x0001 }; // Set fields 1 as constant
32813281
jl_binding_type->name->constfields = binding_constfields;
@@ -3539,7 +3539,7 @@ void jl_init_types(void) JL_GC_DISABLED
35393539
jl_method_type =
35403540
jl_new_datatype(jl_symbol("Method"), core,
35413541
jl_any_type, jl_emptysvec,
3542-
jl_perm_symsvec(31,
3542+
jl_perm_symsvec(32,
35433543
"name",
35443544
"module",
35453545
"file",
@@ -3568,10 +3568,11 @@ void jl_init_types(void) JL_GC_DISABLED
35683568
"isva",
35693569
"is_for_opaque_closure",
35703570
"nospecializeinfer",
3571+
"did_scan_source",
35713572
"constprop",
35723573
"max_varargs",
35733574
"purity"),
3574-
jl_svec(31,
3575+
jl_svec(32,
35753576
jl_symbol_type,
35763577
jl_module_type,
35773578
jl_symbol_type,
@@ -3600,6 +3601,7 @@ void jl_init_types(void) JL_GC_DISABLED
36003601
jl_bool_type,
36013602
jl_bool_type,
36023603
jl_bool_type,
3604+
jl_bool_type,
36033605
jl_uint8_type,
36043606
jl_uint8_type,
36053607
jl_uint16_type),

src/julia.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ typedef struct _jl_method_t {
375375
uint8_t isva;
376376
uint8_t is_for_opaque_closure;
377377
uint8_t nospecializeinfer;
378+
_Atomic(uint8_t) did_scan_source;
379+
378380
// uint8 settings
379381
uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none
380382
uint8_t max_varargs; // 0xFF = use heuristic; otherwise, max # of args to expand
@@ -751,7 +753,10 @@ enum jl_binding_flags {
751753
BINDING_FLAG_DID_PRINT_BACKDATE_ADMONITION = 0x1,
752754
BINDING_FLAG_DID_PRINT_IMPLICIT_IMPORT_ADMONITION = 0x2,
753755
// `export` is tracked in partitions, but sets this as well
754-
BINDING_FLAG_PUBLICP = 0x4
756+
BINDING_FLAG_PUBLICP = 0x4,
757+
// Set if any methods defined in this module implicitly reference
758+
// this binding. If not, invalidation is optimized.
759+
BINDING_FLAG_ANY_IMPLICIT_EDGES = 0x8
755760
};
756761

757762
typedef struct _jl_binding_t {

src/method.c

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,28 @@ static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
3939
}
4040
}
4141

42+
JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
43+
{
44+
if (!jl_atomic_load_relaxed(&m->did_scan_source)) {
45+
jl_code_info_t *code = NULL;
46+
JL_GC_PUSH1(&code);
47+
if (!jl_is_code_info(src))
48+
code = jl_uncompress_ir(m, NULL, src);
49+
else
50+
code = (jl_code_info_t*)src;
51+
jl_array_t *stmts = code->code;
52+
size_t i, l = jl_array_nrows(stmts);
53+
for (i = 0; i < l; i++) {
54+
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
55+
if (jl_is_globalref(stmt)) {
56+
jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
57+
}
58+
}
59+
jl_atomic_store_relaxed(&m->did_scan_source, 1);
60+
JL_GC_POP();
61+
}
62+
}
63+
4264
// Resolve references to non-locally-defined variables to become references to global
4365
// variables in `module` (unless the rvalue is one of the type parameters in `sparam_vals`).
4466
static jl_value_t *resolve_definition_effects(jl_value_t *expr, jl_module_t *module, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
@@ -47,10 +69,7 @@ static jl_value_t *resolve_definition_effects(jl_value_t *expr, jl_module_t *mod
4769
if (jl_is_symbol(expr)) {
4870
jl_error("Found raw symbol in code returned from lowering. Expected all symbols to have been resolved to GlobalRef or slots.");
4971
}
50-
if (jl_is_globalref(expr)) {
51-
jl_maybe_add_binding_backedge((jl_globalref_t*)expr, module, binding_edge);
52-
return expr;
53-
}
72+
5473
if (!jl_is_expr(expr)) {
5574
return expr;
5675
}
@@ -973,6 +992,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
973992
jl_atomic_store_relaxed(&m->deleted_world, 1);
974993
m->is_for_opaque_closure = 0;
975994
m->nospecializeinfer = 0;
995+
jl_atomic_store_relaxed(&m->did_scan_source, 0);
976996
m->constprop = 0;
977997
m->purity.bits = 0;
978998
m->max_varargs = UINT8_MAX;

src/module.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,15 +1345,16 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
13451345
{
13461346
if (!edge)
13471347
return;
1348+
jl_binding_t *b = gr->binding;
1349+
if (!b)
1350+
b = jl_get_module_binding(gr->mod, gr->name, 1);
13481351
// N.B.: The logic for evaluating whether a backedge is required must
13491352
// match the invalidation logic.
13501353
if (gr->mod == defining_module) {
13511354
// No backedge required - invalidation will forward scan
1355+
jl_atomic_fetch_or(&b->flags, BINDING_FLAG_ANY_IMPLICIT_EDGES);
13521356
return;
13531357
}
1354-
jl_binding_t *b = gr->binding;
1355-
if (!b)
1356-
b = jl_get_module_binding(gr->mod, gr->name, 1);
13571358
jl_add_binding_backedge(b, edge);
13581359
}
13591360

test/rebinding.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,23 @@ module RangeMerge
298298

299299
@test !contains(get_llvm(f, Tuple{}), "jl_get_binding_value")
300300
end
301+
302+
# Test that we invalidate for undefined -> defined transitions (#54733)
303+
module UndefinedTransitions
304+
using Test
305+
function foo54733()
306+
for i = 1:1_000_000_000
307+
bar54733(i)
308+
end
309+
return 1
310+
end
311+
@test_throws UndefVarError foo54733()
312+
let ci = first(methods(foo54733)).specializations.cache
313+
@test !Base.Compiler.is_nothrow(Base.Compiler.decode_effects(ci.ipo_purity_bits))
314+
end
315+
bar54733(x) = 3x
316+
@test foo54733() === 1
317+
let ci = first(methods(foo54733)).specializations.cache
318+
@test Base.Compiler.is_nothrow(Base.Compiler.decode_effects(ci.ipo_purity_bits))
319+
end
320+
end

0 commit comments

Comments
 (0)