Skip to content

Commit a3e259e

Browse files
mhauruTechnici4n
andauthored
Use set_valid_worlds! instead of manual optimisation (#205)
* Try out set_valid_worlds! instead of manual optimisation * Add a comment * Give more type information to deref_phi. * Code formatting * Add a comment * Bump patch version to 0.9.8 --------- Co-authored-by: Bruno Ploumhans <[email protected]>
1 parent f2d2ecd commit a3e259e

File tree

4 files changed

+67
-67
lines changed

4 files changed

+67
-67
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.9.9
2+
3+
Remove manual opaque closure optimisation functions in favour of setting the world age and letting the compiler do more work for us, and providing it with some more type information. This changes no functionality, and shouldn't change performance either, but simplifies code.
4+
15
# 0.9.8
26

37
Enables built docs for the current release version of Libtask.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.9.8"
6+
version = "0.9.9"
77

88
[deps]
99
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"

src/copyable_task.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ function build_callable(sig::Type{<:Tuple})
7979
TapedTask from that function."""
8080
throw(ArgumentError(msg))
8181
end
82-
key = CacheKey(Base.get_world_counter(), sig)
82+
world_age = Base.get_world_counter()
83+
key = CacheKey(world_age, sig)
8384
if haskey(mc_cache, key)
8485
return fresh_copy(mc_cache[key])
8586
else
@@ -88,11 +89,15 @@ function build_callable(sig::Type{<:Tuple})
8889
isva = which(sig).isva
8990
bb, refs, types = derive_copyable_task_ir(BBCode(ir))
9091
unoptimised_ir = IRCode(bb)
92+
@static if VERSION > v"1.12-"
93+
# This is a performance optimisation, copied over from Mooncake, where setting
94+
# the valid world age to be very strictly just the current age allows the
95+
# compiler to do more inlining and other optimisation.
96+
unoptimised_ir = set_valid_world!(unoptimised_ir, world_age)
97+
end
9198
optimised_ir = optimise_ir!(unoptimised_ir)
9299
mc_ret_type = callable_ret_type(sig, types)
93-
mc = optimized_misty_closure(
94-
mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true
95-
)
100+
mc = misty_closure(mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true)
96101
mc_cache[key] = mc
97102
return mc, refs[end]
98103
end
@@ -820,9 +825,27 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
820825
deref_ids = map(phi_inds) do n
821826
id = bb.inst_ids[n]
822827
phi_id = phi_ids[n]
828+
ref_ind = ssa_id_to_ref_index_map[id]
823829
push!(
824830
inst_pairs,
825-
(id, new_inst(Expr(:call, deref_phi, refs_id, phi_id))),
831+
# The last argument, ref_index_to_type_map[ref_ind], is a
832+
# performance optimisation. The idea is that we know the inferred
833+
# type of the PhiNode from the original IR, and by passing it to
834+
# deref_phi we can type annotate the element type of the Ref
835+
# that is being dereferenced, resulting in more concrete types
836+
# in the generated IR.
837+
(
838+
id,
839+
new_inst(
840+
Expr(
841+
:call,
842+
deref_phi,
843+
refs_id,
844+
phi_id,
845+
ref_index_to_type_map[ref_ind],
846+
),
847+
),
848+
),
826849
)
827850
return id
828851
end
@@ -1202,8 +1225,11 @@ end
12021225
@inline resume_block_is(refs::R, id::Int32) where {R<:Tuple} = !(refs[end][] === id)
12031226

12041227
# Helper used in `derive_copyable_task_ir`.
1205-
@inline deref_phi(refs::R, n::TupleRef) where {R<:Tuple} = refs[n.n][]
1206-
@inline deref_phi(::R, x) where {R<:Tuple} = x
1228+
@inline function deref_phi(refs::R, n::TupleRef, ::Type{T}) where {R<:Tuple,T}
1229+
ref = refs[n.n]::Base.RefValue{T}
1230+
return ref[]
1231+
end
1232+
@inline deref_phi(::R, x, t::Type) where {R<:Tuple} = x
12071233

12081234
# Helper used in `derive_copyable_task_ir`.
12091235
@inline not_a_produced(x) = !(isa(x, ProducedValue))

src/utils.jl

Lines changed: 29 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -220,55 +220,6 @@ function opaque_closure(
220220
)::Core.OpaqueClosure{sig,ret_type}
221221
end
222222

223-
function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)
224-
oc = opaque_closure(rtype, ir, env...; kwargs...)
225-
world = UInt(oc.world)
226-
set_world_bounds_for_optimization!(oc)
227-
optimized_oc = optimize_opaque_closure(oc, rtype, env...; kwargs...)
228-
return optimized_oc
229-
end
230-
231-
function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)
232-
method = oc.source
233-
ci = method.specializations.cache
234-
world = UInt(oc.world)
235-
ir = reinfer_and_inline(ci, world)
236-
ir === nothing && return oc # nothing to optimize
237-
return opaque_closure(rtype, ir, env...; kwargs...)
238-
end
239-
240-
# Allows optimization to make assumptions about binding access,
241-
# enabling inlining and other optimizations.
242-
function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)
243-
ci = oc.source.specializations.cache
244-
ci.inferred === nothing && return nothing
245-
ci.inferred.min_world = oc.world
246-
return ci.inferred.max_world = oc.world
247-
end
248-
249-
function reinfer_and_inline(ci::Core.CodeInstance, world::UInt)
250-
interp = CC.NativeInterpreter(world)
251-
mi = get_mi(ci)
252-
argtypes = collect(Any, mi.specTypes.parameters)
253-
irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world)
254-
irsv === nothing && return nothing
255-
for stmt in irsv.ir.stmts
256-
inst = stmt[:inst]
257-
if Meta.isexpr(inst, :loopinfo) ||
258-
Meta.isexpr(inst, :pop_exception) ||
259-
isa(inst, CC.GotoIfNot) ||
260-
isa(inst, CC.GotoNode) ||
261-
Meta.isexpr(inst, :copyast)
262-
continue
263-
end
264-
stmt[:flag] |= CC.IR_FLAG_REFINED
265-
end
266-
CC.ir_abstract_constant_propagation(interp, irsv)
267-
state = CC.InliningState(interp)
268-
ir = CC.ssa_inlining_pass!(irsv.ir, state, CC.propagate_inbounds(irsv))
269-
return ir
270-
end
271-
272223
"""
273224
misty_closure(
274225
ret_type::Type,
@@ -291,14 +242,33 @@ function misty_closure(
291242
return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir))
292243
end
293244

294-
function optimized_misty_closure(
295-
ret_type::Type,
296-
ir::IRCode,
297-
@nospecialize env...;
298-
isva::Bool=false,
299-
do_compile::Bool=true,
300-
)
301-
return MistyClosure(
302-
optimized_opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)
303-
)
245+
@static if VERSION > v"1.12-"
246+
"""
247+
set_valid_world!(ir::IRCode, world::UInt)::IRCode
248+
249+
(1.12+ only)
250+
Create a shallow copy of the given IR code, with its `valid_worlds` field updated
251+
to a single valid world. This allows the compiler to perform more inlining.
252+
253+
In particular, if the IR comes from say a function `f` which makes a call to another
254+
function `g` which only got defined after `f`, then at the min_world when `f` was
255+
defined, `g` was not available yet. If we restrict the IR to a world where `g` is
256+
available then `g` can be inlined.
257+
258+
Will error if `world` is not in the existing `valid_worlds` of `ir`.
259+
"""
260+
function set_valid_world!(ir::IRCode, world::UInt)
261+
if world ∉ ir.valid_worlds
262+
error("World $world is not valid for this IRCode: $(ir.valid_worlds).")
263+
end
264+
return CC.IRCode(
265+
ir.stmts,
266+
ir.cfg,
267+
ir.debuginfo,
268+
ir.argtypes,
269+
ir.meta,
270+
ir.sptypes,
271+
CC.WorldRange(world, world),
272+
)
273+
end
304274
end

0 commit comments

Comments
 (0)