Skip to content

Commit f7d6c45

Browse files
committed
make norequire configurable
This PR makes `lines_required!`'s `norequire` logic more configurable. That means the `exclude_named_typedefs` option is now got abstracted, and each consumer can implement its own strategy to escape from the required statement completion by control flow traversal. The motivation for this change is we usually want to respect a control flow in general context, but some consumer may not want that. Especially, JET doesn't want to interpret all the statements within a `try/catch` block, but just select those involved with a method definition. (issue: <aviatesk/JET.jl#150>) For example, `lines_required!` selects statements in the snippet below as JET expects: ```julia for fname in (:foo, :bar, :baz) @eval begin @inline ($(Symbol("is", fname)))(a) = a === $(QuoteNode(fname)) end end ``` , but in the example below, `lines_required` selects "too much" statements and we need a customized `norequire`: ```julia try foo(a) = sum(a) # should be selected (selected initially) foo("julia") # shouldn't be selected, but `lines_required` will select this catch err err # shouldn't be selected, but `lines_required` will select this end ``` Here is an example usage of this customizable `norequire` logic: <aviatesk/JET.jl#152> --- One downside of this change is that now we need to performa the basic block traversal twice when using `norequire = exclude_named_typedefs(src, edges)`. As far as I confirmed, this computation would never be a performance bottleneck, and thus this change hopefully won't hurt the performance. I tried to profile the time with the following snippet: ```julia function select_statements(n, src) for _ in 1:n stmts = src.code isrq = rand(Bool, length(stmts)) edges = CodeEdges(src) norequire = LoweredCodeUtils.exclude_named_typedefs(src, edges) lines_required!(isrq, src, edges, norequire) end end src = code_lowered(...) @profiler select_statements(100, src) ```
1 parent d61d111 commit f7d6c45

File tree

3 files changed

+41
-23
lines changed

3 files changed

+41
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoweredCodeUtils"
22
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
33
authors = ["Tim Holy <[email protected]>"]
4-
version = "1.2.9"
4+
version = "1.3.0"
55

66
[deps]
77
JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"

src/codeedges.jl

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -567,44 +567,61 @@ will end up skipping a subset of such statements, perhaps while repeating others
567567
568568
See also [`lines_required!`](@ref) and [`selective_eval!`](@ref).
569569
"""
570-
function lines_required(obj::Union{Symbol,GlobalRef}, src::CodeInfo, edges::CodeEdges; kwargs...)
570+
function lines_required(obj::Union{Symbol,GlobalRef}, src::CodeInfo, edges::CodeEdges, args...)
571571
isrequired = falses(length(edges.preds))
572572
objs = Set{Union{Symbol,GlobalRef}}([obj])
573-
return lines_required!(isrequired, objs, src, edges; kwargs...)
573+
return lines_required!(isrequired, objs, src, edges, args...)
574574
end
575575

576-
function lines_required(idx::Int, src::CodeInfo, edges::CodeEdges; kwargs...)
576+
function lines_required(idx::Int, src::CodeInfo, edges::CodeEdges, args...)
577577
isrequired = falses(length(edges.preds))
578578
isrequired[idx] = true
579579
objs = Set{Union{Symbol,GlobalRef}}()
580-
return lines_required!(isrequired, src, edges; kwargs...)
580+
return lines_required!(isrequired, objs, src, edges, args...)
581581
end
582582

583583
"""
584-
lines_required!(isrequired::AbstractVector{Bool}, src::CodeInfo, edges::CodeEdges; exclude_named_typedefs::Bool=false)
584+
lines_required!(isrequired::AbstractVector{Bool}, src::CodeInfo, edges::CodeEdges, norequire = ())
585585
586586
Like `lines_required`, but where `isrequired[idx]` has already been set to `true` for all statements
587587
that you know you need to evaluate. All other statements should be marked `false` at entry.
588588
On return, the complete set of required statements will be marked `true`.
589589
590-
Use `exclude_named_typedefs=true` if you're extracting method signatures and not evaluating new definitions.
590+
`norequire` specifies statements (represented as iterator of `Int`s) that should _not_ be
591+
marked as a requirement.
592+
For example, use `norequire = LoweredCodeUtils.exclude_named_typedefs(src, edges)` if you're
593+
extracting method signatures and not evaluating new definitions.
591594
"""
592-
function lines_required!(isrequired::AbstractVector{Bool}, src::CodeInfo, edges::CodeEdges; kwargs...)
595+
function lines_required!(isrequired::AbstractVector{Bool}, src::CodeInfo, edges::CodeEdges, norequire = ())
593596
objs = Set{Union{Symbol,GlobalRef}}()
594-
return lines_required!(isrequired, objs, src, edges; kwargs...)
597+
return lines_required!(isrequired, objs, src, edges, norequire)
595598
end
596599

597-
function lines_required!(isrequired::AbstractVector{Bool}, objs, src::CodeInfo, edges::CodeEdges; exclude_named_typedefs::Bool=false)
600+
function exclude_named_typedefs(src::CodeInfo, edges::CodeEdges)
601+
norequire = BitSet()
602+
i = 1
603+
nstmts = length(src.code)
604+
while i <= nstmts
605+
stmt = rhs(src.code[i])
606+
if istypedef(stmt) && !isanonymous_typedef(stmt::Expr)
607+
r = typedef_range(src, i)
608+
pushall!(norequire, r)
609+
i = last(r)+1
610+
else
611+
i += 1
612+
end
613+
end
614+
return norequire
615+
end
616+
617+
function lines_required!(isrequired::AbstractVector{Bool}, objs, src::CodeInfo, edges::CodeEdges, norequire = ())
598618
# Do a traveral of "numbered" predecessors
599619
# We'll mostly use generic graph traversal to discover all the lines we need,
600620
# but structs are in a bit of a different category (especially on Julia 1.5+).
601621
# It's easiest to discover these at the beginning.
602-
# Moreover, if we're excluding named type definitions, we'll add them to `norequire`
603-
# to prevent them from being marked.
604622
typedef_blocks, typedef_names = UnitRange{Int}[], Symbol[]
605-
norequire = BitSet()
606-
nstmts = length(src.code)
607623
i = 1
624+
nstmts = length(src.code)
608625
while i <= nstmts
609626
stmt = rhs(src.code[i])
610627
if istypedef(stmt) && !isanonymous_typedef(stmt::Expr)
@@ -618,9 +635,6 @@ function lines_required!(isrequired::AbstractVector{Bool}, objs, src::CodeInfo,
618635
isa(name, Symbol) || @show src i r stmt
619636
push!(typedef_names, name::Symbol)
620637
i = last(r)+1
621-
if exclude_named_typedefs && !isanonymous_typedef(stmt)
622-
pushall!(norequire, r)
623-
end
624638
else
625639
i += 1
626640
end
@@ -641,19 +655,22 @@ function lines_required!(isrequired::AbstractVector{Bool}, objs, src::CodeInfo,
641655
iter = 0
642656
while changed
643657
changed = false
658+
644659
# Handle ssa predecessors
645660
for idx = 1:nstmts
646661
if isrequired[idx]
647662
changed |= add_preds!(isrequired, idx, edges, norequire)
648663
end
649664
end
665+
650666
# Handle named dependencies
651667
for (obj, uses) in edges.byname
652668
obj objs && continue
653669
if any(view(isrequired, uses.succs))
654670
changed |= add_obj!(isrequired, objs, obj, edges, norequire)
655671
end
656672
end
673+
657674
# Add control-flow. For any basic block with an evaluated statement inside it,
658675
# check to see if the block has any successors, and if so mark that block's exit statement.
659676
# Likewise, any preceding blocks should have *their* exit statement marked.
@@ -684,6 +701,7 @@ function lines_required!(isrequired::AbstractVector{Bool}, objs, src::CodeInfo,
684701
end
685702
end
686703
end
704+
687705
# So far, everything is generic graph traversal. Now we add some domain-specific information.
688706
# New struct definitions, including their constructors, get spread out over many
689707
# statements. If we're evaluating any of them, it's important to evaluate *all* of them.
@@ -833,7 +851,7 @@ Mark each line of code with its requirement status.
833851
function print_with_code(io::IO, src::CodeInfo, isrequired::AbstractVector{Bool})
834852
nd = ndigits(length(isrequired))
835853
preprint(::IO) = nothing
836-
preprint(io::IO, idx::Int) = print(io, lpad(idx, nd), ' ', isrequired[idx] ? "t " : "f ")
854+
preprint(io::IO, idx::Int) = (c = isrequired[idx]; printstyled(io, lpad(idx, nd), ' ', c ? "t " : "f "; color = c ? :cyan : :plain))
837855
postprint(::IO) = nothing
838856
postprint(io::IO, idx::Int, bbchanged::Bool) = nothing
839857

test/codeedges.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LoweredCodeUtils
22
using LoweredCodeUtils.JuliaInterpreter
3-
using LoweredCodeUtils: callee_matches, istypedef
3+
using LoweredCodeUtils: callee_matches, istypedef, exclude_named_typedefs
44
using JuliaInterpreter: is_global_ref, is_quotenode
55
using Test
66

@@ -22,7 +22,7 @@ function hastrackedexpr(stmt; heads=LoweredCodeUtils.trackedheads)
2222
return false, haseval
2323
end
2424

25-
function minimal_evaluation(predicate, src::Core.CodeInfo, edges::CodeEdges; kwargs...)
25+
function minimal_evaluation(predicate, src::Core.CodeInfo, edges::CodeEdges, args...)
2626
isrequired = fill(false, length(src.code))
2727
for (i, stmt) in enumerate(src.code)
2828
if !isrequired[i]
@@ -33,7 +33,7 @@ function minimal_evaluation(predicate, src::Core.CodeInfo, edges::CodeEdges; kwa
3333
end
3434
end
3535
# All tracked expressions are marked. Now add their dependencies.
36-
lines_required!(isrequired, src, edges; kwargs...)
36+
lines_required!(isrequired, src, edges, args...)
3737
return isrequired
3838
end
3939

@@ -262,7 +262,7 @@ end
262262
frame = Frame(ModEval, ex)
263263
src = frame.framecode.src
264264
edges = CodeEdges(src)
265-
isrequired = minimal_evaluation(stmt->(LoweredCodeUtils.ismethod3(stmt),false), src, edges; exclude_named_typedefs=true) # initially mark only the constructor
265+
isrequired = minimal_evaluation(stmt->(LoweredCodeUtils.ismethod3(stmt),false), src, edges, exclude_named_typedefs(src, edges)) # initially mark only the constructor
266266
bbs = Core.Compiler.compute_basic_blocks(src.code)
267267
for (iblock, block) in enumerate(bbs.blocks)
268268
r = LoweredCodeUtils.rng(block)
@@ -301,7 +301,7 @@ end
301301
src = thk.args[1]
302302
edges = CodeEdges(src)
303303
idx = findfirst(stmt->Meta.isexpr(stmt, :method), src.code)
304-
lr = lines_required(idx, src, edges; exclude_named_typedefs=true)
304+
lr = lines_required(idx, src, edges, exclude_named_typedefs(src, edges))
305305
idx = findfirst(stmt->Meta.isexpr(stmt, :(=)) && Meta.isexpr(stmt.args[2], :call) && is_global_ref(stmt.args[2].args[1], Core, :Box), src.code)
306306
@test lr[idx]
307307
# but make sure we don't break primitivetype & abstracttype (https://github.com/timholy/Revise.jl/pull/611)

0 commit comments

Comments
 (0)