Skip to content

Commit 5bc1095

Browse files
committed
Improve detection of struct dependencies
1 parent fd5d0c1 commit 5bc1095

File tree

2 files changed

+175
-81
lines changed

2 files changed

+175
-81
lines changed

src/codeedges.jl

Lines changed: 140 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -192,18 +192,34 @@ function namedkeys(cl::CodeLinks)
192192
end
193193

194194
function direct_links!(cl::CodeLinks, src::CodeInfo)
195+
# Utility for when a stmt itself contains a CodeInfo
196+
function add_inner!(cl::CodeLinks, icl::CodeLinks, idx)
197+
for (name, _) in icl.nameassigns
198+
assigns = get(cl.nameassigns, name, nothing)
199+
if assigns === nothing
200+
cl.nameassigns[name] = assigns = Int[]
201+
end
202+
push!(assigns, idx)
203+
end
204+
for (name, _) in icl.namesuccs
205+
succs = get(cl.namesuccs, name, nothing)
206+
if succs === nothing
207+
cl.namesuccs[name] = succs = Links()
208+
end
209+
push!(succs.ssas, idx)
210+
end
211+
end
212+
195213
for (i, stmt) in enumerate(src.code)
196214
if isexpr(stmt, :thunk) && isa(stmt.args[1], CodeInfo)
197215
icl = CodeLinks(stmt.args[1])
198-
for (name, _) in icl.nameassigns
199-
assign = get(cl.nameassigns, name, nothing)
200-
if assign === nothing
201-
cl.nameassigns[name] = assign = Int[]
202-
end
203-
push!(assign, i)
204-
end
216+
add_inner!(cl, icl, i)
205217
continue
206218
elseif isa(stmt, Expr) && stmt.head ∈ trackedheads
219+
if stmt.head === :method && length(stmt.args) === 3 && isa(stmt.args[3], CodeInfo)
220+
icl = CodeLinks(stmt.args[3])
221+
add_inner!(cl, icl, i)
222+
end
207223
name = stmt.args[1]
208224
if isa(name, Symbol)
209225
assign = get(cl.nameassigns, name, nothing)
@@ -511,6 +527,27 @@ function postprint_lineedges(io::IO, idx::Int, edges::CodeEdges, bbchanged::Bool
511527
return nothing
512528
end
513529

530+
function terminal_preds(i::Int, edges::CodeEdges)
531+
function terminal_preds!(s, j, edges, covered)
532+
j ∈ covered && return s
533+
push!(covered, j)
534+
preds = edges.preds[j]
535+
if isempty(preds)
536+
push!(s, j)
537+
else
538+
for p in preds
539+
terminal_preds!(s, p, edges, covered)
540+
end
541+
end
542+
return s
543+
end
544+
s, covered = BitSet(), BitSet()
545+
push!(covered, i)
546+
for p in edges.preds[i]
547+
terminal_preds!(s, p, edges, covered)
548+
end
549+
return s
550+
end
514551

515552
"""
516553
isrequired = lines_required(obj::Union{Symbol,GlobalRef}, src::CodeInfo, edges::CodeEdges)
@@ -551,95 +588,153 @@ end
551588
function lines_required!(isrequired::AbstractVector{Bool}, objs, src::CodeInfo, edges::CodeEdges)
552589
# Do a traversal of "numbered" predecessors
553590
function add_preds!(isrequired, idx, edges::CodeEdges)
591+
changed = false
554592
preds = edges.preds[idx]
555593
for p in preds
556594
isrequired[p] && continue
557595
isrequired[p] = true
596+
changed = true
558597
add_preds!(isrequired, p, edges)
559598
end
560-
return isrequired
599+
return changed
561600
end
562601
function add_succs!(isrequired, idx, edges::CodeEdges, succs)
602+
changed = false
563603
for p in succs
564604
isrequired[p] && continue
565605
isrequired[p] = true
606+
changed = true
566607
add_succs!(isrequired, p, edges, edges.succs[p])
567608
end
568-
return isrequired
609+
return changed
610+
end
611+
function add_obj!(isrequired, objs, obj, edges::CodeEdges)
612+
changed = false
613+
for d in edges.byname[obj].assigned
614+
isrequired[d] || add_preds!(isrequired, d, edges)
615+
isrequired[d] = true
616+
changed = true
617+
end
618+
push!(objs, obj)
619+
return changed
569620
end
570621

622+
objsnew = Set{Union{Symbol,GlobalRef}}()
623+
for obj in objs
624+
add_obj!(isrequired, objsnew, obj, edges)
625+
end
626+
objs = objsnew
571627
bbs = Core.Compiler.compute_basic_blocks(src.code) # needed for control-flow analysis
572628
changed = true
573629
iter = 0
574630
while changed
575631
changed = false
576-
# Add "named" object dependencies
577-
for obj in objs
578-
def = edges.byname[obj].assigned
579-
if !all(i->isrequired[i], def)
580-
changed = true
581-
for d in def
582-
isrequired[d] = true
583-
add_preds!(isrequired, d, edges)
584-
if isexpr(src.code[d], :thunk) && startswith(String(obj), '#')
585-
# For anonymous types, we also want their associated methods
586-
add_succs!(isrequired, d, edges, edges.byname[obj].succs)
587-
end
588-
end
589-
end
590-
end
591-
# Add "numbered" dependencies
632+
# Handle ssa predecessors
592633
for idx = 1:length(isrequired)
593634
if isrequired[idx]
594-
preds = edges.preds[idx]
595-
if !all(i->isrequired[i], preds)
596-
changed = true
597-
isrequired[preds] .= true
598-
end
635+
changed |= add_preds!(isrequired, idx, edges)
636+
end
637+
end
638+
# Handle named dependencies
639+
for (obj, uses) in edges.byname
640+
obj ∈ objs && continue
641+
if any(view(isrequired, uses.succs))
642+
changed |= add_obj!(isrequired, objs, obj, edges)
599643
end
600644
end
601645
# Add control-flow. For any basic block with an evaluated statement inside it,
602646
# check to see if the block has any successors, and if so mark that block's exit statement.
603647
# Likewise, any preceding blocks should have *their* exit statement marked.
604-
for (i, bb) in enumerate(bbs.blocks)
648+
for (ibb, bb) in enumerate(bbs.blocks)
605649
r = rng(bb)
606650
if any(view(isrequired, r))
607651
# if !isempty(bb.succs)
608-
if i != length(bbs.blocks)
652+
if ibb != length(bbs.blocks)
609653
idxlast = r[end]
610654
changed |= !isrequired[idxlast]
611655
isrequired[idxlast] = true
612656
end
613-
for ibb in bb.preds
614-
rpred = rng(bbs.blocks[ibb])
657+
for ibbp in bb.preds
658+
rpred = rng(bbs.blocks[ibbp])
615659
idxlast = rpred[end]
616660
changed |= !isrequired[idxlast]
617661
isrequired[idxlast] = true
618662
end
619-
for ibb in bb.succs
620-
ibb == length(bbs.blocks) && continue
621-
rpred = rng(bbs.blocks[ibb])
663+
for ibbs in bb.succs
664+
ibbs == length(bbs.blocks) && continue
665+
rpred = rng(bbs.blocks[ibbs])
622666
idxlast = rpred[end]
623667
changed |= !isrequired[idxlast]
624668
isrequired[idxlast] = true
625669
end
626670
end
627671
end
628-
# In preparation for the next round, add any new named objects
629-
# required by these dependencies
630-
for (obj, uses) in edges.byname
631-
obj ∈ objs && continue
632-
if any(i->isrequired[i], uses.succs)
633-
push!(objs, obj)
634-
changed = true
672+
# So far, everything is generic graph traversal. Now we add some domain-specific information.
673+
# New struct definitions, including their constructors, get spread out over many
674+
# statements. If we're evaluating any of them, it's important to evaluate *all* of them.
675+
for (idx, stmt) in enumerate(src.code)
676+
isrequired[idx] || continue
677+
if isexpr(stmt, :(=))
678+
stmt = stmt.args[2]
679+
end
680+
# Is this a struct definition?
681+
if (isa(stmt, Expr) && stmt.head ∈ structheads) || # < Julia 1.5
682+
(isexpr(stmt, :call) && callee_matches(stmt.args[1], Core, :_structtype)) # >= Julia 1.5
683+
stmt = stmt::Expr
684+
name = stmt.args[stmt.head === :call ? 3 : 1]
685+
if isa(name, QuoteNode)
686+
name = name.value
687+
end
688+
name = name::NamedVar
689+
# Some lines we need have been marked as successors of this name
690+
for d in edges.byname[name].succs
691+
stmt2 = src.code[d]
692+
if isa(stmt2, Expr)
693+
head = stmt2.head
694+
if head === :method || head === :global || head === :const
695+
changed |= !isrequired[d]
696+
isrequired[d] = true
697+
end
698+
end
699+
# Julia 1.5+: others are successor of a slotnum->ssa load
700+
if isslotnum(stmt2)
701+
for s in edges.succs[d]
702+
stmt3 = src.code[s]
703+
if isexpr(stmt3, :call) && (callee_matches(stmt3.args[1], Core, :_setsuper!) ||
704+
callee_matches(stmt3.args[1], Core, :_typebody!))
705+
changed |= !isrequired[s]
706+
isrequired[s] = true
707+
end
708+
end
709+
end
710+
# Julia 1.5+: for non-parametric types, the Core._setsuper! call happens without the slotname->ssa load
711+
if isexpr(stmt2, :call) && callee_matches((stmt2::Expr).args[1], Core, :_setsuper!)
712+
changed |= !isrequired[d]
713+
isrequired[d] = true
714+
end
715+
end
716+
end
717+
# Anonymous functions may not yet include the method definition
718+
if isexpr(stmt, :thunk) && isanonymous_typedef(stmt.args[1])
719+
i = idx + 1
720+
while i <= length(src.code) && !ismethod3(src.code[i])
721+
i += 1
722+
end
723+
if i <= length(src.code) && src.code[i].args[1] == false
724+
tpreds = terminal_preds(i, edges)
725+
if minimum(tpreds) == idx
726+
changed |= !isrequired[i]
727+
isrequired[i] = true
728+
end
729+
end
635730
end
636731
end
637732
iter += 1 # just for diagnostics
638733
end
639734
return isrequired
640735
end
641736

642-
function caller_matches(f, mod, sym)
737+
function callee_matches(f, mod, sym)
643738
is_global_ref(f, mod, sym) && return true
644739
if isdefined(mod, sym)
645740
is_quotenode(f, getfield(mod, sym)) && return true

test/codeedges.jl

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,17 @@
11
using LoweredCodeUtils
22
using LoweredCodeUtils.JuliaInterpreter
3-
using LoweredCodeUtils: caller_matches
3+
using LoweredCodeUtils: callee_matches
44
using JuliaInterpreter: is_global_ref, is_quotenode
55
using Test
66

7-
# # Utilities for finding statements corresponding to particular calls
8-
# function callpredicate(@nospecialize(stmt), fsym)
9-
# isa(stmt, Expr) || return false
10-
# head = stmt.head
11-
# if head === :(=)
12-
# return callpredicate(stmt.args[2], fsym)
13-
# end
14-
# return LoweredCodeUtils.iscallto(stmt, fsym)
15-
# end
16-
177
function hastrackedexpr(stmt; heads=LoweredCodeUtils.trackedheads)
188
haseval = false
199
if isa(stmt, Expr)
2010
if stmt.head === :call
2111
f = stmt.args[1]
22-
haseval = f === :eval || (caller_matches(f, Base, :getproperty) && is_quotenode(stmt.args[2], :eval))
23-
caller_matches(f, Core, :_typebody!) && return true, haseval
24-
caller_matches(f, Core, :_setsuper!) && return true, haseval
12+
haseval = f === :eval || (callee_matches(f, Base, :getproperty) && is_quotenode(stmt.args[2], :eval))
13+
callee_matches(f, Core, :_typebody!) && return true, haseval
14+
callee_matches(f, Core, :_setsuper!) && return true, haseval
2515
f === :include && return true, haseval
2616
elseif stmt.head === :thunk
2717
any(s->any(hastrackedexpr(s; heads=heads)), stmt.args[1].code) && return true, haseval
@@ -32,32 +22,19 @@ function hastrackedexpr(stmt; heads=LoweredCodeUtils.trackedheads)
3222
return false, haseval
3323
end
3424

35-
function minimal_evaluation(src::Core.CodeInfo, edges::CodeEdges)
36-
musteval = fill(false, length(src.code))
25+
function minimal_evaluation(predicate, src::Core.CodeInfo, edges::CodeEdges)
26+
isrequired = fill(false, length(src.code))
3727
for (i, stmt) in enumerate(src.code)
38-
if !musteval[i]
39-
musteval[i], haseval = hastrackedexpr(stmt)
28+
if !isrequired[i]
29+
isrequired[i], haseval = predicate(stmt)
4030
if haseval
41-
musteval[edges.succs[i]] .= true
31+
isrequired[edges.succs[i]] .= true
4232
end
4333
end
4434
end
4535
# All tracked expressions are marked. Now add their dependencies.
46-
lines_required!(musteval, src, edges)
47-
# Struct definitions likely omitted const/global. Look for them via name.
48-
for (name, var) in edges.byname
49-
if !isempty(var.assigned) && any(i->musteval[i], var.succs)
50-
foreach(var.succs) do i
51-
stmt = src.code[i]
52-
if isa(stmt, Expr)
53-
if stmt.head === :global || stmt.head === :const
54-
musteval[i] = true
55-
end
56-
end
57-
end
58-
end
59-
end
60-
return musteval
36+
lines_required!(isrequired, src, edges)
37+
return isrequired
6138
end
6239

6340
function allmissing(mod::Module, names)
@@ -209,18 +186,40 @@ end
209186
# Check that the StructParent name is discovered everywhere it is used
210187
var = edges.byname[:StructParent]
211188
@test var.preds[end] ∈ var.succs
212-
isrequired = minimal_evaluation(src, edges)
189+
isrequired = minimal_evaluation(hastrackedexpr, src, edges)
213190
selective_eval_fromstart!(frame, isrequired, true)
214191
@test supertype(ModSelective.StructParent) === AbstractArray
215192
# Also check redefinition (it's OK when the definition doesn't change)
216193
Core.eval(ModEval, ex)
217194
frame = JuliaInterpreter.prepare_thunk(ModEval, ex)
218195
src = frame.framecode.src
219196
edges = CodeEdges(src)
220-
isrequired = minimal_evaluation(src, edges)
197+
isrequired = minimal_evaluation(hastrackedexpr, src, edges)
221198
selective_eval_fromstart!(frame, isrequired, true)
222199
@test supertype(ModEval.StructParent) === AbstractArray
223200

201+
# Finding all dependencies in a struct definition
202+
# Nonparametric
203+
ex = :(struct NoParam end)
204+
frame = JuliaInterpreter.prepare_thunk(ModSelective, ex)
205+
src = frame.framecode.src
206+
edges = CodeEdges(src)
207+
isrequired = minimal_evaluation(stmt->(LoweredCodeUtils.ismethod3(stmt)&&stmt.args[1]===:NoParam,false), src, edges) # initially mark only the constructor
208+
selective_eval_fromstart!(frame, isrequired, true)
209+
@test isa(ModSelective.NoParam(), ModSelective.NoParam)
210+
# Parametric
211+
ex = quote
212+
struct Struct{T} <: StructParent{T,1}
213+
x::Vector{T}
214+
end
215+
end
216+
frame = JuliaInterpreter.prepare_thunk(ModSelective, ex)
217+
src = frame.framecode.src
218+
edges = CodeEdges(src)
219+
isrequired = minimal_evaluation(stmt->(LoweredCodeUtils.ismethod3(stmt)&&stmt.args[1]===:Struct,false), src, edges) # initially mark only the constructor
220+
selective_eval_fromstart!(frame, isrequired, true)
221+
@test isa(ModSelective.Struct([1,2,3]), ModSelective.Struct{Int})
222+
224223
# Anonymous functions
225224
ex = :(max_values(T::Union{map(X -> Type{X}, Base.BitIntegerSmall_types)...}) = 1 << (8*sizeof(T)))
226225
frame = JuliaInterpreter.prepare_thunk(ModSelective, ex)

0 commit comments

Comments
 (0)