Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.9.9

Remove manual opaque closure optimisation functions in favour of setting the world age and letting the compiler do more work for us, and providing it with some more type information. This changes no functionality, and shouldn't change performance either, but simplifies code.

# 0.9.8

Enables built docs for the current release version of Libtask.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.8"
version = "0.9.9"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
40 changes: 33 additions & 7 deletions src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ function build_callable(sig::Type{<:Tuple})
TapedTask from that function."""
throw(ArgumentError(msg))
end
key = CacheKey(Base.get_world_counter(), sig)
world_age = Base.get_world_counter()
key = CacheKey(world_age, sig)
if haskey(mc_cache, key)
return fresh_copy(mc_cache[key])
else
Expand All @@ -88,11 +89,15 @@ function build_callable(sig::Type{<:Tuple})
isva = which(sig).isva
bb, refs, types = derive_copyable_task_ir(BBCode(ir))
unoptimised_ir = IRCode(bb)
@static if VERSION > v"1.12-"
# This is a performance optimisation, copied over from Mooncake, where setting
# the valid world age to be very strictly just the current age allows the
# compiler to do more inlining and other optimisation.
unoptimised_ir = set_valid_world!(unoptimised_ir, world_age)
end
optimised_ir = optimise_ir!(unoptimised_ir)
mc_ret_type = callable_ret_type(sig, types)
mc = optimized_misty_closure(
mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true
)
mc = misty_closure(mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true)
mc_cache[key] = mc
return mc, refs[end]
end
Expand Down Expand Up @@ -820,9 +825,27 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
deref_ids = map(phi_inds) do n
id = bb.inst_ids[n]
phi_id = phi_ids[n]
ref_ind = ssa_id_to_ref_index_map[id]
push!(
inst_pairs,
(id, new_inst(Expr(:call, deref_phi, refs_id, phi_id))),
# The last argument, ref_index_to_type_map[ref_ind], is a
# performance optimisation. The idea is that we know the inferred
# type of the PhiNode from the original IR, and by passing it to
# deref_phi we can type annotate the element type of the Ref
# that is being dereferenced, resulting in more concrete types
# in the generated IR.
(
id,
new_inst(
Expr(
:call,
deref_phi,
refs_id,
phi_id,
ref_index_to_type_map[ref_ind],
),
),
),
)
return id
end
Expand Down Expand Up @@ -1202,8 +1225,11 @@ end
# Reads the value held by the last element of `refs` (which must support `[]`
# dereferencing) and returns `true` when that value is NOT identical (`!==`) to `id`,
# `false` when it is.
# NOTE(review): the result is the negation of what the name suggests; presumably the
# caller relies on this inverted sense — confirm against the call sites.
@inline function resume_block_is(refs::R, id::Int32) where {R<:Tuple}
    return refs[end][] !== id
end

# Helper used in `derive_copyable_task_ir`.
# Two-argument methods. When the second argument is a `TupleRef`, its field `n` is used
# as an index into `refs` and the selected element is dereferenced with `[]`. Any other
# second argument is returned unchanged.
@inline function deref_phi(refs::R, n::TupleRef) where {R<:Tuple}
    return refs[n.n][]
end
@inline function deref_phi(::R, x) where {R<:Tuple}
    return x
end

# Three-argument methods. These behave like the two-argument ones, except that the
# `TupleRef` method additionally asserts that the selected element of `refs` is a
# `Base.RefValue{T}` before dereferencing it, where `T` is the type passed as the third
# argument. The fallback ignores the type argument and returns `x` unchanged.
@inline function deref_phi(refs::R, n::TupleRef, ::Type{T}) where {R<:Tuple,T}
    return (refs[n.n]::Base.RefValue{T})[]
end
@inline function deref_phi(::R, x, t::Type) where {R<:Tuple}
    return x
end

# Helper used in `derive_copyable_task_ir`.
# Returns `true` when `x` is not an instance of `ProducedValue`, `false` otherwise.
@inline not_a_produced(x) = !(x isa ProducedValue)
Expand Down
88 changes: 29 additions & 59 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,55 +220,6 @@ function opaque_closure(
)::Core.OpaqueClosure{sig,ret_type}
end

"""
    optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)

Build an opaque closure from `ir` via `opaque_closure`, restrict the world bounds of its
inferred code with `set_world_bounds_for_optimization!`, and then pass it through
`optimize_opaque_closure`.

`rtype`, `env...` and `kwargs...` are forwarded unchanged to `opaque_closure` and
`optimize_opaque_closure`. Returns whatever `optimize_opaque_closure` returns.
"""
function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)
    oc = opaque_closure(rtype, ir, env...; kwargs...)
    # The world bounds must be narrowed before re-optimising.
    set_world_bounds_for_optimization!(oc)
    return optimize_opaque_closure(oc, rtype, env...; kwargs...)
end

"""
    optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)

Re-run inference and inlining on the code instance cached for `oc`'s source method (via
`reinfer_and_inline`, at `oc`'s world) and build a new opaque closure from the result.

Returns `oc` itself when `reinfer_and_inline` yields `nothing`; otherwise returns
`opaque_closure(rtype, ir, env...; kwargs...)` for the re-optimised IR.
"""
function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)
    code_instance = oc.source.specializations.cache
    refined_ir = reinfer_and_inline(code_instance, UInt(oc.world))
    if refined_ir === nothing
        # Nothing to optimise: hand back the closure we were given.
        return oc
    end
    return opaque_closure(rtype, refined_ir, env...; kwargs...)
end

# Pins both `min_world` and `max_world` of the inferred code cached for `oc`'s source
# method to `oc.world`. Per the original author's note, this allows optimization to make
# assumptions about binding access, enabling inlining and other optimizations.
#
# Returns `nothing` when there is no inferred code to update; otherwise returns
# `oc.world` (the value of the final assignment).
function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)
    ci = oc.source.specializations.cache
    if ci.inferred === nothing
        return nothing
    end
    ci.inferred.min_world = oc.world
    ci.inferred.max_world = oc.world
    return oc.world
end

"""
    reinfer_and_inline(ci::Core.CodeInstance, world::UInt)

Build an `IRInterpretationState` for `ci` at `world`, flag its statements with
`CC.IR_FLAG_REFINED`, run `CC.ir_abstract_constant_propagation` over it, and then run
`CC.ssa_inlining_pass!` on the resulting IR.

Statements that are `GotoIfNot` / `GotoNode` nodes, or `Expr`s with head `:loopinfo`,
`:pop_exception` or `:copyast`, are left unflagged.

Returns the IR produced by the inlining pass, or `nothing` when no
`IRInterpretationState` could be constructed.
"""
function reinfer_and_inline(ci::Core.CodeInstance, world::UInt)
    interp = CC.NativeInterpreter(world)
    mi = get_mi(ci)
    argtypes = collect(Any, mi.specTypes.parameters)
    irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world)
    if irsv === nothing
        return nothing
    end
    # Expression heads whose statements must not be flagged.
    exempt_heads = (:loopinfo, :pop_exception, :copyast)
    for stmt in irsv.ir.stmts
        inst = stmt[:inst]
        is_exempt =
            Meta.isexpr(inst, exempt_heads) ||
            isa(inst, CC.GotoIfNot) ||
            isa(inst, CC.GotoNode)
        if !is_exempt
            # NOTE(review): presumably IR_FLAG_REFINED asks the interpreter to revisit
            # this statement — confirm against the compiler sources.
            stmt[:flag] |= CC.IR_FLAG_REFINED
        end
    end
    CC.ir_abstract_constant_propagation(interp, irsv)
    inlining_state = CC.InliningState(interp)
    return CC.ssa_inlining_pass!(irsv.ir, inlining_state, CC.propagate_inbounds(irsv))
end

"""
misty_closure(
ret_type::Type,
Expand All @@ -291,14 +242,33 @@ function misty_closure(
return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir))
end

function optimized_misty_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)
return MistyClosure(
optimized_opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)
)
@static if VERSION > v"1.12-"
    """
        set_valid_world!(ir::IRCode, world::UInt)::IRCode

    (1.12+ only)

    Create a shallow copy of the given IR code, with its `valid_worlds` field updated
    to a single valid world. This allows the compiler to perform more inlining.

    In particular, if the IR comes from, say, a function `f` which makes a call to
    another function `g` which only got defined after `f`, then at the min_world when
    `f` was defined, `g` was not available yet. If we restrict the IR to a world where
    `g` is available then `g` can be inlined.

    Will error if `world` is not in the existing `valid_worlds` of `ir`.
    """
    function set_valid_world!(ir::IRCode, world::UInt)
        # Refuse to narrow to a world outside the range the IR is already valid for.
        if world ∉ ir.valid_worlds
            error("World $world is not valid for this IRCode: $(ir.valid_worlds).")
        end
        # All other fields are reused as-is (not copied); only the world range changes.
        return CC.IRCode(
            ir.stmts,
            ir.cfg,
            ir.debuginfo,
            ir.argtypes,
            ir.meta,
            ir.sptypes,
            CC.WorldRange(world, world),
        )
    end
end
Loading