Skip to content

Commit 6aeaac3

Browse files
committed
Implement dynamic scope
1 parent b803f58 commit 6aeaac3

File tree

3 files changed

+102
-234
lines changed

3 files changed

+102
-234
lines changed

src/copyable_task.jl

Lines changed: 78 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,122 +1,79 @@
1+
const dynamic_scope = Base.ScopedValues.ScopedValue{Any}(0)
2+
3+
"""
4+
get_dynamic_scope()
5+
6+
Returns the dynamic scope associated to `Libtask`. If called from inside a `TapedTask`, this
7+
will return whatever is contained in its `dynamic_scope` field.
8+
9+
See also [`set_dynamic_scope!`](@ref).
10+
"""
11+
get_dynamic_scope() = dynamic_scope[]
12+
113
__v::Int = 5
214
@noinline function produce(x)
315
global __v = 4
416
return nothing
517
end
618

7-
mutable struct TapedTask{Tmc<:MistyClosure,Targs}
8-
const mc::Tmc
19+
function build_callable(ir::IRCode)
20+
seed_id!()
21+
bb, refs = derive_copyable_task_ir(BBCode(ir))
22+
ir = IRCode(bb)
23+
optimised_ir = Mooncake.optimise_ir!(ir)
24+
return MistyClosure(optimised_ir, refs...; do_compile=true), refs[end]
25+
end
26+
27+
mutable struct TapedTask{Tdynamic_scope,Targs,Tmc<:MistyClosure}
28+
dynamic_scope::Tdynamic_scope
929
args::Targs
10-
const position::Base.RefValue{Int32}
11-
const deepcopy_types::Type
30+
const mc::Tmc
31+
const position::Base.RefValue{Int32} # As does this
1232
end
1333

1434
"""
15-
Base.copy(t::TapedTask)
16-
17-
Makes a copy of `t` which can be run. For the most part, calls to [`consume`](@ref) on the
18-
copied task will give the same results as the original. There are, however, substantial
19-
limitations to this, detailed in the extended help.
35+
TapedTask(dynamic_scope::Any, f, args...)
2036
21-
# Extended Help
22-
23-
We call a copy of a `TapedTask` _consistent_ with the original if the call to `==` in the
24-
loop below always returns `true`:
25-
```julia
26-
t = <some_TapedTask>
27-
tc = copy(t)
28-
for (v, vc) in zip(t, tc)
29-
v == vc
30-
end
31-
```
32-
(provided that `==` is implemented for all `v` that are produced). Convesely, we refer to a
33-
copy as _inconsistent_ if this property doesn't hold. In order to ensure
34-
consistency, we need to ensure that independent copies are made of anything which might be
35-
mutated by the task or its copy during subsequent `consume` calls. Failure to do this can
36-
cause problems if, for example, a task reads-to and writes-from some memory.
37-
If we call `consume` on the original task, and then on a copy of it, any changes made by the
38-
original will be visible to the copy, potentially causing its behaviour to differ. This can
39-
manifest itself as a race condition if the task and its copies are run concurrently.
40-
41-
To understand a bit more about when a task is / is not consistent, we need to dig into the
42-
rather specific semantics of `copy`. Calling `copy` on a `TapedTask` does the following:
43-
1. `copy` the `position` field,
44-
2. `map`s `_tape_copy` over the `args` field, and
45-
3. `map`s `_tape_copy` over the all of the data closed over in the `OpaqueClosure` which
46-
implements the task (specifically the values _inside_ the `Ref`s) -- call these the
47-
`captures`. Except the last elements of this data, because this is `===` to the
48-
`position` field -- for this element we use the copy we made in step 1.
49-
50-
`_tape_copy` doesn't actually make a copy of the object at all if it is not either an
51-
`Array`, a `Ref`, or an instance of one of the types listed in the task's `deepcopy_types`
52-
field. If it is an instance of one of these types then `_tape_copy` just calls `deepcopy`.
53-
54-
This behaviour is plainly entirely acceptable if the argument to `_tape_copy` is a bits
55-
type. For any `mutable struct`s which aren't flagged for `deepcopy`ing, we have an immediate
56-
risk of inconsistency. Similarly, for any `struct` types which aren't bits types (e.g.
57-
those which contain an `Array`, `Ref`, or some other `mutable struct` either directly as one
58-
of their fields, or as a field of a field, etc), we have an inconsistency risk.
59-
60-
Furthermore, for anything which _is_ `deepcopy`ed we introduce inconsistency risks. If, for
61-
example, two elements of the data closed over by the task alias one another, calling
62-
`deepcopy` on them separately will cause the copies to _not_ alias one another.
63-
The same thing can happen if one element is `deepcopy`ed and the other not. For example, if
64-
we have both an `Array` `x` and `view(x, inds)` stored in separate elements of `captures`,
65-
`x` will be `deepcopy`ed, while `view(x, inds)` will not. In the copy of `captures`, the
66-
`view` will still be a view into the original `x`, not the `deepcopy`ed version. Again, this
67-
introduces inconsistency.
68-
69-
Why do we have these semantics? We have them because Libtask has always had them, and at the
70-
time of writing we're unsure whether AdvancedPS.jl, and by extension Turing.jl rely on this
71-
behaviour.
72-
73-
What other options do we have? Simply calling `deepcopy` on a `TapedTask` works fine, and
74-
should reliably result in consistent behaviour between a `TapedTask` and any copies of it.
75-
This would, therefore, be a preferable implementation. We should try to determine whether
76-
this is a viable option.
37+
Construct a `TapedTask` with the specified `dynamic_scope`, for function `f` and positional
38+
arguments `args`.
7739
"""
78-
function Base.copy(t::T) where {T<:TapedTask}
79-
captures = t.mc.oc.captures
80-
new_captures = map(Base.Fix2(_copy_capture, t.deepcopy_types), captures)
81-
new_position = new_captures[end] # baked in later on.
82-
new_args = map(Base.Fix2(_tape_copy, t.deepcopy_types), t.args)
83-
new_mc = Mooncake.replace_captures(t.mc, new_captures)
84-
return T(new_mc, new_args, new_position, t.deepcopy_types)
85-
end
86-
87-
function _copy_capture(r::Ref{T}, deepcopy_types::Type) where {T}
88-
new_capture = Ref{T}()
89-
if isassigned(r)
90-
new_capture[] = _tape_copy(r[], deepcopy_types)
91-
end
92-
return new_capture
40+
function TapedTask(dynamic_scope::Any, fargs...)
41+
mc, count_ref = build_callable(Base.code_ircode_by_type(typeof(fargs))[1][1])
42+
return TapedTask(dynamic_scope, fargs[2:end], mc, count_ref)
9343
end
9444

95-
_tape_copy(v, deepcopy_types::Type) = v isa deepcopy_types ? deepcopy(v) : v
96-
97-
# Not sure that we need this in the new implementation.
98-
_tape_copy(box::Core.Box, deepcopy_types::Type) = error("Found a box")
45+
"""
46+
set_dynamic_scope!(t::TapedTask, new_dynamic_scope)::Nothing
9947
100-
@inline consume(t::TapedTask) = t.mc(t.args...)
48+
Set the `dynamic_scope` of `t` to `new_dynamic_scope`. Any references to
49+
`LibTask.dynamic_scope` in future calls to `consume(t)` (either directly, or implicitly via
50+
iteration) will see this new value.
10151
102-
function initialise!(t::TapedTask, args::Vararg{Any,N})::Nothing where {N}
103-
t.position[] = -1
104-
t.args = args
52+
See also: [`get_dynamic_scope`](@ref).
53+
"""
54+
function set_dynamic_scope!(t::TapedTask{T}, new_dynamic_scope::T)::Nothing where {T}
55+
t.dynamic_scope = new_dynamic_scope
10556
return nothing
10657
end
10758

108-
function TapedTask(fargs...; deepcopy_types::Type=Union{})
109-
sig = typeof(fargs)
110-
mc, count_ref = build_callable(Base.code_ircode_by_type(sig)[1][1])
111-
return TapedTask(mc, fargs[2:end], count_ref, Union{deepcopy_types,Array,Ref})
112-
end
59+
"""
60+
Base.copy(t::TapedTask)
11361
114-
function build_callable(ir::IRCode)
115-
seed_id!()
116-
bb, refs = derive_copyable_task_ir(BBCode(ir))
117-
ir = IRCode(bb)
118-
optimised_ir = Mooncake.optimise_ir!(ir)
119-
return MistyClosure(optimised_ir, refs...; do_compile=true), refs[end]
62+
Makes a completely independent copy of `t`. `consume` can be applied to either the copy of
63+
`t` or the original without advancing the state of the other.
64+
"""
65+
Base.copy(t::T) where {T<:TapedTask} = deepcopy(t)
66+
67+
"""
68+
consume(t::TapedTask)
69+
70+
Run `t` until it makes a call to `produce`. If this is the first time that `t` has been
71+
called, it start execution from the entry point. If `consume` has previously been called on
72+
`t`, it will resume from the last `produce` call. If there are no more `produce` calls,
73+
`nothing` will be returned.
74+
"""
75+
@inline function consume(t::TapedTask)
76+
return Base.ScopedValues.with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope)
12077
end
12178

12279
"""
@@ -288,7 +245,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
288245
n += 1
289246
ssa_id_to_ref_index_map[id] = n
290247
ref_index_to_ssa_id_map[n] = id
291-
ref_index_to_type_map[n] = stmt.type
248+
ref_index_to_type_map[n] = CC.widenconst(stmt.type)
292249
end
293250
end
294251

@@ -382,8 +339,25 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
382339
push!(inst_pairs, (id, inst))
383340
elseif stmt isa Nothing
384341
push!(inst_pairs, (id, inst))
342+
elseif stmt isa GlobalRef
343+
ref_ind = ssa_id_to_ref_index_map[id]
344+
expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt)
345+
push!(inst_pairs, (id, new_inst(expr)))
346+
elseif stmt isa Core.PiNode
347+
if stmt.val isa ID
348+
ref_ind = ssa_id_to_ref_index_map[stmt.val]
349+
val_id = ID()
350+
expr = Expr(:call, get_ref_at, refs_id, ref_ind)
351+
push!(inst_pairs, (val_id, new_inst(expr)))
352+
push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ))))
353+
else
354+
push!(inst_pairs, (id, inst))
355+
end
356+
set_ind = ssa_id_to_ref_index_map[id]
357+
set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id)
358+
push!(inst_pairs, (ID(), new_inst(set_expr)))
385359
else
386-
throw(error("Unhandled stmt $stmt"))
360+
throw(error("Unhandled stmt $stmt of type $(typeof(stmt))"))
387361
end
388362
end
389363

@@ -451,7 +425,9 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
451425
end
452426

453427
# Helper used in `derive_copyable_task_ir`.
454-
@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][]
428+
@inline function get_ref_at(refs::R, n::Int) where {R<:Tuple}
429+
return refs[n][]
430+
end
455431

456432
# Helper used in `derive_copyable_task_ir`.
457433
@inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple}

src/test_utils.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ..Libtask: TapedTask
66

77
struct Testcase
88
name::String
9+
dynamic_scope::Any
910
fargs::Tuple
1011
expected_iteration_results::Vector
1112
end
@@ -14,7 +15,7 @@ function (case::Testcase)()
1415
testset = @testset "$(case.name)" begin
1516

1617
# Construct the task.
17-
t = TapedTask(case.fargs...)
18+
t = TapedTask(case.dynamic_scope, case.fargs...)
1819

1920
# Iterate through t. Record the results, and take a copy after each iteration.
2021
iteration_results = []
@@ -39,21 +40,22 @@ function test_cases()
3940
return Testcase[
4041
Testcase(
4142
"single block",
43+
nothing,
4244
(single_block, 5.0),
4345
[sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))],
4446
),
45-
Testcase("produce old", (produce_old_value, 5.0), [sin(5.0), sin(5.0)]),
46-
Testcase("branch on old value l", (branch_on_old_value, 2.0), [true, 2.0]),
47-
Testcase("branch on old value r", (branch_on_old_value, -1.0), [false, -2.0]),
48-
Testcase("no produce", (no_produce_test, 5.0, 4.0), []),
49-
Testcase("new object", (new_object_test, 5, 4), [C(5, 4), C(5, 4)]),
50-
Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]),
51-
Testcase("branching test r", (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]),
52-
Testcase("unused argument test", (unused_argument_test, 3), [1]),
53-
Testcase("test with const", (test_with_const,), [1]),
54-
Testcase("while loop", (while_loop,), collect(1:9)),
47+
Testcase("produce old", nothing, (produce_old_value, 5.0), [sin(5.0), sin(5.0)]),
48+
Testcase("branch on old value l", nothing, (branch_on_old_value, 2.0), [true, 2.0]),
49+
Testcase("branch on old value r", nothing, (branch_on_old_value, -1.0), [false, -2.0]),
50+
Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), []),
51+
Testcase("new object", nothing, (new_object_test, 5, 4), [C(5, 4), C(5, 4)]),
52+
Testcase("branching test l", nothing, (branching_test, 5.0, 4.0), [string(sin(5.0))]),
53+
Testcase("branching test r", nothing, (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]),
54+
Testcase("unused argument test", nothing, (unused_argument_test, 3), [1]),
55+
Testcase("test with const", nothing, (test_with_const,), [1]),
56+
Testcase("while loop", nothing, (while_loop,), collect(1:9)),
5557
Testcase(
56-
"foreigncall tester", (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}]
58+
"foreigncall tester", nothing, (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}]
5759
),
5860

5961
# Failing tests

0 commit comments

Comments
 (0)