Skip to content

Commit 57b808c

Browse files
committed
Dynamic nested calls and uses of return values of calls which might produce
1 parent f94ea17 commit 57b808c

File tree

2 files changed

+130
-34
lines changed

2 files changed

+130
-34
lines changed

src/copyable_task.jl

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -241,22 +241,34 @@ function is_produce_stmt(x)::Bool
241241
end
242242

243243
"""
244-
stmt_might_produce(x)::Bool
244+
stmt_might_produce(x, ret_type::Type)::Bool
245245
246246
`true` if `x` might contain a call to `produce`, and `false` otherwise.
247247
"""
248-
function stmt_might_produce(x)::Bool
248+
function stmt_might_produce(x, ret_type::Type)::Bool
249+
250+
# Statement will terminate in an unusual fashion, so don't bother recursing.
251+
# This isn't _strictly_ correct (there could be a `produce` statement before the
252+
# `throw` call is hit), but this seems unlikely to happen in practice. If it does, the
253+
# user should get a sensible error message anyway.
254+
ret_type == Union{} && return false
255+
256+
# Statement will terminate in the usual fashion, so _do_ bother recusing.
249257
is_produce_stmt(x) && return true
250258
Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes)
259+
if Meta.isexpr(x, :call)
260+
# This is a hack -- it's perfectly possible for `DataType` calls to produce in general.
261+
f = get_function(x.args[1])
262+
_might_produce = !isa(f, Union{Core.IntrinsicFunction,Core.Builtin,DataType})
263+
return _might_produce
264+
end
251265
return false
252-
253-
# # TODO: make this correct
254-
# Meta.isexpr(x, :call) &&
255-
# return !isa(x.args[1], Union{Core.IntrinsicFunction,Core.Builtin})
256-
# Meta.isexpr(x, :invoke) && return false # todo: make this more accurate
257-
# return false
258266
end
259267

268+
get_function(x) = x
269+
get_function(x::Expr) = eval(x)
270+
get_function(x::GlobalRef) = isconst(x) ? getglobal(x.mod, x.name) : x.binding
271+
260272
"""
261273
produce_value(x::Expr)
262274
@@ -347,7 +359,11 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
347359
# in the block at which the corresponding split starts and finishes.
348360
all_splits = map(ir.blocks) do block
349361
split_ends = vcat(
350-
findall(inst -> stmt_might_produce(inst.stmt), block.insts), length(block)
362+
findall(
363+
inst -> stmt_might_produce(inst.stmt, CC.widenconst(inst.type)),
364+
block.insts,
365+
),
366+
length(block),
351367
)
352368
return map(enumerate(split_ends)) do (n, split_end)
353369
return (start=(n == 1 ? 0 : split_ends[n - 1]) + 1, last=split_end)
@@ -654,48 +670,63 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
654670
push!(new_blocks, BBlock(splits_ids[n], inst_pairs))
655671

656672
# Derive TapedTask for this statement.
657-
callable = if Meta.isexpr(stmt, :invoke)
673+
(callable, callable_args) = if Meta.isexpr(stmt, :invoke)
658674
sig = stmt.args[1].specTypes
659-
LazyCallable{sig,callable_ret_type(sig)}()
675+
(LazyCallable{sig,callable_ret_type(sig)}(), stmt.args[2:end])
676+
elseif Meta.isexpr(stmt, :call)
677+
(DynamicCallable(), stmt.args)
660678
else
679+
display(stmt)
680+
println()
661681
error("unhandled statement which might produce $stmt")
662682
end
663683

684+
# Find any `ID`s and replace them with calls to read whatever is stored
685+
# in the `Ref`s that they are associated to.
686+
callable_inst_pairs = Mooncake.IDInstPair[]
687+
for (n, arg) in enumerate(callable_args)
688+
arg isa ID || continue
689+
690+
new_id = ID()
691+
ref_ind = ssa_id_to_ref_index_map[arg]
692+
expr = Expr(:call, get_ref_at, refs_id, ref_ind)
693+
push!(callable_inst_pairs, (new_id, new_inst(expr)))
694+
callable_args[n] = new_id
695+
end
696+
664697
# Allocate a slot in the _refs vector for this callable.
665698
push!(_refs, Ref(callable))
666699
callable_ind = length(_refs)
667700

668701
# Retrieve the callable from the refs.
669702
callable_id = ID()
670-
callable = Expr(:call, get_ref_at, refs_id, callable_ind)
703+
callable_stmt = Expr(:call, get_ref_at, refs_id, callable_ind)
704+
push!(callable_inst_pairs, (callable_id, new_inst(callable_stmt)))
671705

672706
# Call the callable.
673-
result = Expr(:call, callable_id, stmt.args[3:end]...)
674707
result_id = ID()
708+
result_stmt = Expr(:call, callable_id, callable_args...)
709+
push!(callable_inst_pairs, (result_id, new_inst(result_stmt)))
675710

676711
# Determine whether this TapedTask has produced a not-a-`ProducedValue`.
677-
not_produced = Expr(:call, not_a_produced, result_id)
678712
not_produced_id = ID()
713+
not_produced_stmt = Expr(:call, not_a_produced, result_id)
714+
push!(callable_inst_pairs, (not_produced_id, new_inst(not_produced_stmt)))
679715

680716
# Go to a block which just returns the `ProducedValue`, if a
681717
# `ProducedValue` is returned, otherwise continue to the next split.
682718
is_produced_block_id = ID()
683-
next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator
684-
# switch = Switch(Any[not_produced_id], [is_produced_block_id], next_block_id)
685-
switch = IDGotoIfNot(not_produced_id, is_produced_block_id)
686-
687-
# Insert a new block to hold the three previous statements.
688-
callable_inst_pairs = Mooncake.IDInstPair[
689-
(callable_id, new_inst(callable)),
690-
(result_id, new_inst(result)),
691-
(not_produced_id, new_inst(not_produced)),
692-
(ID(), new_inst(switch)),
693-
]
719+
is_not_produced_block_id = ID()
720+
switch = Switch(
721+
Any[not_produced_id],
722+
[is_produced_block_id],
723+
is_not_produced_block_id,
724+
)
725+
push!(callable_inst_pairs, (ID(), new_inst(switch)))
726+
727+
# Push the above statements onto a new block.
694728
push!(new_blocks, BBlock(callable_block_id, callable_inst_pairs))
695729

696-
goto_block = BBlock(ID(), [(ID(), new_inst(IDGotoNode(next_block_id)))])
697-
push!(new_blocks, goto_block)
698-
699730
# Construct block which handles the case that we got a `ProducedValue`. If
700731
# this happens, it means that `callable` has more things to produce still.
701732
# This means that we need to call it again next time we enter this function.
@@ -709,6 +740,21 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
709740
(return_id, new_inst(ReturnNode(result_id))),
710741
]
711742
push!(new_blocks, BBlock(is_produced_block_id, produced_block_inst_pairs))
743+
744+
# Construct block which handles the case that we did not get a
745+
# `ProducedValue`. In this case, we must first push the result to the `Ref`
746+
# associated to the call, and goto the next split.
747+
next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator
748+
result_ref_ind = ssa_id_to_ref_index_map[id]
749+
set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id)
750+
not_produced_block_inst_pairs = Mooncake.IDInstPair[
751+
(ID(), new_inst(set_ref))
752+
(ID(), new_inst(IDGotoNode(next_block_id)))
753+
]
754+
push!(
755+
new_blocks,
756+
BBlock(is_not_produced_block_id, not_produced_block_inst_pairs),
757+
)
712758
end
713759
return new_blocks
714760
end
@@ -784,7 +830,7 @@ end
784830

785831
function (l::LazyCallable)(args::Vararg{Any,N}) where {N}
786832
isdefined(l, :mc) || construct_callable!(l)
787-
return l.mc(args...)
833+
return l.mc(args[2:end]...)
788834
end
789835

790836
function construct_callable!(l::LazyCallable{sig}) where {sig}
@@ -793,3 +839,19 @@ function construct_callable!(l::LazyCallable{sig}) where {sig}
793839
l.position = pos
794840
return nothing
795841
end
842+
843+
mutable struct DynamicCallable{V}
844+
cache::V
845+
end
846+
847+
DynamicCallable() = DynamicCallable(Dict{Any,Any}())
848+
849+
function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N}
850+
sig = Mooncake._typeof(args)
851+
callable = get(dynamic_callable.cache, sig, nothing)
852+
if callable === nothing
853+
callable = build_callable(sig)
854+
dynamic_callable.cache[sig] = callable
855+
end
856+
return callable[1](args[2:end]...)
857+
end

src/test_utils.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,25 @@ function test_cases()
6868
),
6969
Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), [5]),
7070
Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]),
71-
72-
# Failing tests
73-
Testcase("nested", nothing, (nested_outer,), [true, false]),
71+
Testcase("nested (static)", nothing, (static_nested_outer,), [true, false]),
72+
Testcase(
73+
"nested (static + used)",
74+
nothing,
75+
(static_nested_outer_use_produced,),
76+
[true, 1],
77+
),
78+
Testcase(
79+
"nested (dynamic)",
80+
nothing,
81+
(dynamic_nested_outer, Ref{Any}(nested_inner)),
82+
[true, false],
83+
),
84+
Testcase(
85+
"nested (dynamic + used)",
86+
nothing,
87+
(dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)),
88+
[true, 1],
89+
),
7490
]
7591
end
7692

@@ -165,15 +181,33 @@ end
165181

166182
@noinline function nested_inner()
167183
produce(true)
168-
return nothing
184+
return 1
169185
end
170186

171187
Libtask.might_produce(::Type{Tuple{typeof(nested_inner)}}) = true
172188

173-
function nested_outer()
189+
function static_nested_outer()
174190
nested_inner()
175191
produce(false)
176192
return nothing
177193
end
178194

195+
function static_nested_outer_use_produced()
196+
y = nested_inner()
197+
produce(y)
198+
return nothing
199+
end
200+
201+
function dynamic_nested_outer(f::Ref{Any})
202+
f[]()
203+
produce(false)
204+
return nothing
205+
end
206+
207+
function dynamic_nested_outer_use_produced(f::Ref{Any})
208+
y = f[]()
209+
produce(y)
210+
return nothing
211+
end
212+
179213
end

0 commit comments

Comments
 (0)