Skip to content

Commit a8a1c3b

Browse files
committed
Align module base between invalidation and edge tracking
Our implicit edge tracking for bindings does not explicitly store any edges for bindings in the *current* module. The idea behind this is that this is a good time-space tradeoff for validation, because substantially all binding references in a module will be to its defining module, while the total number of methods within a module is limited and substantially smaller than the total number of methods in the entire system. However, we have an issue where the code that stores these edges and the invalidation code disagree on which module is the *current* one. The edge storing code was using the module in which the method was defined, while the invalidation code was using the one in which the MethodTable is defined. With these being misaligned, we can miss necessary invalidations. Both options are in principle possible, but I think the former is better, because the module in which the method is defined is also the module that we are likely to have a lot of references to (since they get referenced implicitly by just writing symbols in the code). However, this presents a problem: We don't actually have a way to iterate all the methods defined in a particular module, without just doing the brute force thing of scanning all methods and filtering. To address this, build on the deferred scanning code added in #57615 to also add any scanned methods to an explicit list in `Module`. This costs some space, but only proportional to the number of defined methods (and thus proportional to the written source code). Note that we don't actually observe any issues in the test suite on master due to this bug. However, this is because we are grossly over-invalidating, which hides the missing invalidations from this issue (#57617).
1 parent e7efe42 commit a8a1c3b

File tree

9 files changed

+74
-35
lines changed

9 files changed

+74
-35
lines changed

Compiler/src/utilities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function retrieve_code_info(mi::MethodInstance, world::UInt)
129129
else
130130
c = copy(src::CodeInfo)
131131
end
132-
if !def.did_scan_source
132+
if (def.did_scan_source & 0x1) == 0x0
133133
# This scan must happen:
134134
# 1. After method definition
135135
# 2. Before any code instances that may have relied on information

base/invalidation.jl

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,10 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
136136

137137
if need_to_invalidate_code
138138
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
139-
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
140-
for method in MethodList(mt)
141-
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
142-
end
143-
return true
139+
nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), gr.mod)
140+
for i = 1:nmethods
141+
method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), gr.mod, i)::Method
142+
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
144143
end
145144
end
146145
if isdefined(b, :backedges)
@@ -166,7 +165,7 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
166165
# have a binding that is affected by this change.
167166
usings_backedges = ccall(:jl_get_module_usings_backedges, Any, (Any,), gr.mod)
168167
if usings_backedges !== nothing
169-
for user in usings_backedges::Vector{Any}
168+
for user::Module in usings_backedges::Vector{Any}
170169
user_binding = ccall(:jl_get_module_binding_or_nothing, Any, (Any, Any), user, gr.name)
171170
user_binding === nothing && continue
172171
isdefined(user_binding, :partitions) || continue
@@ -186,21 +185,10 @@ end
186185
invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_bpart::Core.BindingPartition, new_max_world::UInt) =
187186
invalidate_code_for_globalref!(convert(Core.Binding, gr), invalidated_bpart, new_bpart, new_max_world)
188187

189-
gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod
190-
191-
# N.B.: This needs to match jl_maybe_add_binding_backedge
192188
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
193-
method = isa(edge, Method) ? edge : edge.def.def::Method
194-
methmod = method.module
195-
if !gr_needs_backedge_in_module(b.globalref, methmod)
196-
@atomic :acquire_release b.flags |= BINDING_FLAG_ANY_IMPLICIT_EDGES
197-
return
198-
end
199-
if !isdefined(b, :backedges)
200-
b.backedges = Any[]
201-
end
202-
!isempty(b.backedges) && b.backedges[end] === edge && return
203-
push!(b.backedges, edge)
189+
meth = isa(edge, Method) ? edge : Compiler.get_ci_mi(edge).def
190+
ccall(:jl_maybe_add_binding_backedge, Cint, (Any, Any, Any), b, edge, meth)
191+
return nothing
204192
end
205193

206194
function binding_was_invalidated(b::Core.Binding)

src/gc-stock.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,9 @@ STATIC_INLINE void gc_mark_module_binding(jl_ptls_t ptls, jl_module_t *mb_parent
21472147
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->usings_backedges);
21482148
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->usings_backedges, &nptr);
21492149
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->usings_backedges);
2150+
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->scanned_methods);
2151+
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->scanned_methods, &nptr);
2152+
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->scanned_methods);
21502153
size_t nusings = module_usings_length(mb_parent);
21512154
if (nusings > 0) {
21522155
// this is only necessary because bindings for "using" modules

src/jltypes.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3601,7 +3601,7 @@ void jl_init_types(void) JL_GC_DISABLED
36013601
jl_bool_type,
36023602
jl_bool_type,
36033603
jl_bool_type,
3604-
jl_bool_type,
3604+
jl_uint8_type,
36053605
jl_uint8_type,
36063606
jl_uint8_type,
36073607
jl_uint16_type),

src/julia.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ typedef struct _jl_method_t {
375375
uint8_t isva;
376376
uint8_t is_for_opaque_closure;
377377
uint8_t nospecializeinfer;
378+
// bit flags, 0x01 = scanned
379+
// 0x02 = added to module scanned list (either from scanning or inference edge)
378380
_Atomic(uint8_t) did_scan_source;
379381

380382
// uint8 settings
@@ -782,6 +784,7 @@ typedef struct _jl_module_t {
782784
jl_sym_t *file;
783785
int32_t line;
784786
jl_value_t *usings_backedges;
787+
jl_value_t *scanned_methods;
785788
// hidden fields:
786789
arraylist_t usings; /* arraylist of struct jl_module_using */ // modules with all bindings potentially imported
787790
jl_uuid_t build_id;
@@ -2059,6 +2062,7 @@ JL_DLLEXPORT int jl_get_module_infer(jl_module_t *m);
20592062
JL_DLLEXPORT void jl_set_module_max_methods(jl_module_t *self, int value);
20602063
JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m);
20612064
JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m);
2065+
JL_DLLEXPORT jl_value_t *jl_get_module_scanned_methods(jl_module_t *m);
20622066
JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s);
20632067

20642068
// get binding for reading

src/julia_internal.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ast);
722722
JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
723723
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
724724
int binding_effects);
725-
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
725+
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_binding_t *b, jl_value_t *edge, jl_method_t *in_method);
726726
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);
727727

728728
int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
@@ -878,6 +878,7 @@ STATIC_INLINE size_t module_usings_max(jl_module_t *m) JL_NOTSAFEPOINT {
878878
}
879879

880880
JL_DLLEXPORT jl_sym_t *jl_module_name(jl_module_t *m) JL_NOTSAFEPOINT;
881+
void jl_add_scanned_method(jl_module_t *m, jl_method_t *meth);
881882
jl_value_t *jl_eval_global_var(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *e);
882883
jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *clos, jl_value_t **args, size_t nargs);
883884
jl_value_t *jl_interpret_toplevel_thunk(jl_module_t *m, jl_code_info_t *src);

src/method.c

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,20 @@ static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
3939
}
4040
}
4141

42+
// Append `meth` to the `scanned_methods` list of module `m`.
// The list is created on first use: while the field still holds `jl_nothing`,
// a vector is allocated with `jl_alloc_vec_any(0)` and stored into the module,
// followed by a GC write barrier on `m`.
// The check, the allocation, and the push all happen while holding `m->lock`.
void jl_add_scanned_method(jl_module_t *m, jl_method_t *meth)
43+
{
44+
JL_LOCK(&m->lock);
45+
if (m->scanned_methods == jl_nothing) {
46+
m->scanned_methods = (jl_value_t*)jl_alloc_vec_any(0);
47+
jl_gc_wb(m, m->scanned_methods);
48+
}
49+
jl_array_ptr_1d_push((jl_array_t*)m->scanned_methods, (jl_value_t*)meth);
50+
JL_UNLOCK(&m->lock);
51+
}
52+
4253
JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
4354
{
44-
if (!jl_atomic_load_relaxed(&m->did_scan_source)) {
55+
if (!jl_atomic_fetch_or(&m->did_scan_source, 1)) {
4556
jl_code_info_t *code = NULL;
4657
JL_GC_PUSH1(&code);
4758
if (!jl_is_code_info(src))
@@ -50,13 +61,19 @@ JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
5061
code = (jl_code_info_t*)src;
5162
jl_array_t *stmts = code->code;
5263
size_t i, l = jl_array_nrows(stmts);
64+
int any_implicit = 0;
5365
for (i = 0; i < l; i++) {
5466
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
5567
if (jl_is_globalref(stmt)) {
56-
jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
68+
jl_globalref_t *gr = (jl_globalref_t*)stmt;
69+
jl_binding_t *b = gr->binding;
70+
if (!b)
71+
b = jl_get_module_binding(gr->mod, gr->name, 1);
72+
any_implicit |= jl_maybe_add_binding_backedge(b, (jl_value_t*)m, m);
5773
}
5874
}
59-
jl_atomic_store_relaxed(&m->did_scan_source, 1);
75+
if (any_implicit && !(jl_atomic_fetch_or(&m->did_scan_source, 0x2) & 0x2))
76+
jl_add_scanned_method(m->module, m);
6077
JL_GC_POP();
6178
}
6279
}

src/module.c

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ STATIC_INLINE jl_binding_partition_t *jl_get_binding_partition_(jl_binding_t *b
209209
if (!new_bpart)
210210
new_bpart = new_binding_partition();
211211
jl_atomic_store_relaxed(&new_bpart->next, bpart);
212-
jl_gc_wb(new_bpart, bpart); // Not fresh the second time around the loop
212+
if (bpart)
213+
jl_gc_wb(new_bpart, bpart); // Not fresh the second time around the loop
213214
new_bpart->min_world = bpart ? jl_atomic_load_relaxed(&bpart->max_world) + 1 : 0;
214215
jl_atomic_store_relaxed(&new_bpart->max_world, max_world);
215216
JL_GC_PROMISE_ROOTED(new_bpart); // TODO: Analyzer doesn't understand MAYBE_UNROOTED properly
@@ -319,6 +320,7 @@ JL_DLLEXPORT jl_module_t *jl_new_module__(jl_sym_t *name, jl_module_t *parent)
319320
m->build_id.hi = ~(uint64_t)0;
320321
jl_atomic_store_relaxed(&m->counter, 1);
321322
m->usings_backedges = jl_nothing;
323+
m->scanned_methods = jl_nothing;
322324
m->nospecialize = 0;
323325
m->optlevel = -1;
324326
m->compile = -1;
@@ -1163,6 +1165,25 @@ JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m)
11631165
return m->usings_backedges;
11641166
}
11651167

1168+
// Return the number of entries in `m->scanned_methods`, or 0 when the field
// still holds `jl_nothing` (no list has been allocated for this module yet).
// The field is read while holding `m->lock`.
JL_DLLEXPORT size_t jl_module_scanned_methods_length(jl_module_t *m)
1169+
{
1170+
JL_LOCK(&m->lock);
1171+
size_t len = 0;
1172+
if (m->scanned_methods != jl_nothing)
1173+
len = jl_array_len(m->scanned_methods);
1174+
JL_UNLOCK(&m->lock);
1175+
return len;
1176+
}
1177+
1178+
// Return element `i` of `m->scanned_methods`, where `i` is 1-based (the array
// is read at offset `i-1`).
// The list must already exist; that precondition is only checked with
// `assert`, and this function performs no explicit range check on `i` itself.
// The read happens while holding `m->lock`.
JL_DLLEXPORT jl_value_t *jl_module_scanned_methods_getindex(jl_module_t *m, size_t i)
1179+
{
1180+
JL_LOCK(&m->lock);
1181+
assert(m->scanned_methods != jl_nothing);
1182+
jl_value_t *ret = jl_array_ptr_ref(m->scanned_methods, i-1);
1183+
JL_UNLOCK(&m->lock);
1184+
return ret;
1185+
}
1186+
11661187
JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s)
11671188
{
11681189
jl_binding_t *b = jl_get_module_binding(m, s, 0);
@@ -1369,21 +1390,22 @@ JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)
13691390

13701391
// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
13711392
// GlobalRefs.
1372-
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
1393+
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_binding_t *b, jl_value_t *edge, jl_method_t *for_method)
13731394
{
13741395
if (!edge)
1375-
return;
1376-
jl_binding_t *b = gr->binding;
1377-
if (!b)
1378-
b = jl_get_module_binding(gr->mod, gr->name, 1);
1396+
return 0;
1397+
jl_module_t *defining_module = for_method->module;
13791398
// N.B.: The logic for evaluating whether a backedge is required must
13801399
// match the invalidation logic.
1381-
if (gr->mod == defining_module) {
1400+
if (b->globalref->mod == defining_module) {
13821401
// No backedge required - invalidation will forward scan
13831402
jl_atomic_fetch_or(&b->flags, BINDING_FLAG_ANY_IMPLICIT_EDGES);
1384-
return;
1403+
if (!(jl_atomic_fetch_or(&for_method->did_scan_source, 0x2) & 0x2))
1404+
jl_add_scanned_method(for_method->module, for_method);
1405+
return 1;
13851406
}
1386-
jl_add_binding_backedge(b, edge);
1407+
jl_add_binding_backedge(b, (jl_value_t*)edge);
1408+
return 0;
13871409
}
13881410

13891411
JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked(jl_binding_t *b,

src/staticdata.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ static void jl_queue_module_for_serialization(jl_serializer_state *s, jl_module_
812812
}
813813

814814
jl_queue_for_serialization(s, m->usings_backedges);
815+
jl_queue_for_serialization(s, m->scanned_methods);
815816
}
816817

817818
// Anything that requires uniquing or fixing during deserialization needs to be "toplevel"
@@ -1324,6 +1325,9 @@ static void jl_write_module(jl_serializer_state *s, uintptr_t item, jl_module_t
13241325
newm->usings_backedges = NULL;
13251326
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, usings_backedges)));
13261327
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->usings_backedges, s->link_ids_relocs));
1328+
newm->scanned_methods = NULL;
1329+
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, scanned_methods)));
1330+
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->scanned_methods, s->link_ids_relocs));
13271331

13281332
// After reload, everything that has happened in this process happened semantically at
13291333
// (for .incremental) or before jl_require_world, so reset this flag.

0 commit comments

Comments
 (0)