@@ -241,22 +241,34 @@ function is_produce_stmt(x)::Bool
241241end
242242
243243"""
244- stmt_might_produce(x)::Bool
244+ stmt_might_produce(x, ret_type::Type )::Bool
245245
246246`true` if `x` might contain a call to `produce`, and `false` otherwise.
247247"""
248- function stmt_might_produce (x):: Bool
248+ function stmt_might_produce (x, ret_type:: Type ):: Bool
249+
250+ # Statement will terminate in an unusual fashion, so don't bother recursing.
251+ # This isn't _strictly_ correct (there could be a `produce` statement before the
252+ # `throw` call is hit), but this seems unlikely to happen in practice. If it does, the
253+ # user should get a sensible error message anyway.
254+ ret_type == Union{} && return false
255+
256+ # Statement will terminate in the usual fashion, so _do_ bother recusing.
249257 is_produce_stmt (x) && return true
250258 Meta. isexpr (x, :invoke ) && return might_produce (x. args[1 ]. specTypes)
259+ if Meta. isexpr (x, :call )
260+ # This is a hack -- it's perfectly possible for `DataType` calls to produce in general.
261+ f = get_function (x. args[1 ])
262+ _might_produce = ! isa (f, Union{Core. IntrinsicFunction,Core. Builtin,DataType})
263+ return _might_produce
264+ end
251265 return false
252-
253- # # TODO : make this correct
254- # Meta.isexpr(x, :call) &&
255- # return !isa(x.args[1], Union{Core.IntrinsicFunction,Core.Builtin})
256- # Meta.isexpr(x, :invoke) && return false # todo: make this more accurate
257- # return false
258266end
259267
268+ get_function (x) = x
269+ get_function (x:: Expr ) = eval (x)
270+ get_function (x:: GlobalRef ) = isconst (x) ? getglobal (x. mod, x. name) : x. binding
271+
260272"""
261273 produce_value(x::Expr)
262274
@@ -347,7 +359,11 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
347359 # in the block at which the corresponding split starts and finishes.
348360 all_splits = map (ir. blocks) do block
349361 split_ends = vcat (
350- findall (inst -> stmt_might_produce (inst. stmt), block. insts), length (block)
362+ findall (
363+ inst -> stmt_might_produce (inst. stmt, CC. widenconst (inst. type)),
364+ block. insts,
365+ ),
366+ length (block),
351367 )
352368 return map (enumerate (split_ends)) do (n, split_end)
353369 return (start= (n == 1 ? 0 : split_ends[n - 1 ]) + 1 , last= split_end)
@@ -654,48 +670,63 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
654670 push! (new_blocks, BBlock (splits_ids[n], inst_pairs))
655671
656672 # Derive TapedTask for this statement.
657- callable = if Meta. isexpr (stmt, :invoke )
673+ ( callable, callable_args) = if Meta. isexpr (stmt, :invoke )
658674 sig = stmt. args[1 ]. specTypes
659- LazyCallable {sig,callable_ret_type(sig)} ()
675+ (LazyCallable {sig,callable_ret_type(sig)} (), stmt. args[2 : end ])
676+ elseif Meta. isexpr (stmt, :call )
677+ (DynamicCallable (), stmt. args)
660678 else
679+ display (stmt)
680+ println ()
661681 error (" unhandled statement which might produce $stmt " )
662682 end
663683
684+ # Find any `ID`s and replace them with calls to read whatever is stored
685+ # in the `Ref`s that they are associated to.
686+ callable_inst_pairs = Mooncake. IDInstPair[]
687+ for (n, arg) in enumerate (callable_args)
688+ arg isa ID || continue
689+
690+ new_id = ID ()
691+ ref_ind = ssa_id_to_ref_index_map[arg]
692+ expr = Expr (:call , get_ref_at, refs_id, ref_ind)
693+ push! (callable_inst_pairs, (new_id, new_inst (expr)))
694+ callable_args[n] = new_id
695+ end
696+
664697 # Allocate a slot in the _refs vector for this callable.
665698 push! (_refs, Ref (callable))
666699 callable_ind = length (_refs)
667700
668701 # Retrieve the callable from the refs.
669702 callable_id = ID ()
670- callable = Expr (:call , get_ref_at, refs_id, callable_ind)
703+ callable_stmt = Expr (:call , get_ref_at, refs_id, callable_ind)
704+ push! (callable_inst_pairs, (callable_id, new_inst (callable_stmt)))
671705
672706 # Call the callable.
673- result = Expr (:call , callable_id, stmt. args[3 : end ]. .. )
674707 result_id = ID ()
708+ result_stmt = Expr (:call , callable_id, callable_args... )
709+ push! (callable_inst_pairs, (result_id, new_inst (result_stmt)))
675710
676711 # Determine whether this TapedTask has produced a not-a-`ProducedValue`.
677- not_produced = Expr (:call , not_a_produced, result_id)
678712 not_produced_id = ID ()
713+ not_produced_stmt = Expr (:call , not_a_produced, result_id)
714+ push! (callable_inst_pairs, (not_produced_id, new_inst (not_produced_stmt)))
679715
680716 # Go to a block which just returns the `ProducedValue`, if a
681717 # `ProducedValue` is returned, otherwise continue to the next split.
682718 is_produced_block_id = ID ()
683- next_block_id = splits_ids[n + 1 ] # safe since the last split ends with a terminator
684- # switch = Switch(Any[not_produced_id], [is_produced_block_id], next_block_id)
685- switch = IDGotoIfNot (not_produced_id, is_produced_block_id)
686-
687- # Insert a new block to hold the three previous statements.
688- callable_inst_pairs = Mooncake. IDInstPair[
689- (callable_id, new_inst (callable)),
690- (result_id, new_inst (result)),
691- (not_produced_id, new_inst (not_produced)),
692- (ID (), new_inst (switch)),
693- ]
719+ is_not_produced_block_id = ID ()
720+ switch = Switch (
721+ Any[not_produced_id],
722+ [is_produced_block_id],
723+ is_not_produced_block_id,
724+ )
725+ push! (callable_inst_pairs, (ID (), new_inst (switch)))
726+
727+ # Push the above statements onto a new block.
694728 push! (new_blocks, BBlock (callable_block_id, callable_inst_pairs))
695729
696- goto_block = BBlock (ID (), [(ID (), new_inst (IDGotoNode (next_block_id)))])
697- push! (new_blocks, goto_block)
698-
699730 # Construct block which handles the case that we got a `ProducedValue`. If
700731 # this happens, it means that `callable` has more things to produce still.
701732 # This means that we need to call it again next time we enter this function.
@@ -709,6 +740,21 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
709740 (return_id, new_inst (ReturnNode (result_id))),
710741 ]
711742 push! (new_blocks, BBlock (is_produced_block_id, produced_block_inst_pairs))
743+
744+ # Construct block which handles the case that we did not get a
745+ # `ProducedValue`. In this case, we must first push the result to the `Ref`
746+ # associated to the call, and goto the next split.
747+ next_block_id = splits_ids[n + 1 ] # safe since the last split ends with a terminator
748+ result_ref_ind = ssa_id_to_ref_index_map[id]
749+ set_ref = Expr (:call , set_ref_at!, refs_id, result_ref_ind, result_id)
750+ not_produced_block_inst_pairs = Mooncake. IDInstPair[
751+ (ID (), new_inst (set_ref))
752+ (ID (), new_inst (IDGotoNode (next_block_id)))
753+ ]
754+ push! (
755+ new_blocks,
756+ BBlock (is_not_produced_block_id, not_produced_block_inst_pairs),
757+ )
712758 end
713759 return new_blocks
714760 end
784830
785831function (l:: LazyCallable )(args:: Vararg{Any,N} ) where {N}
786832 isdefined (l, :mc ) || construct_callable! (l)
787- return l. mc (args... )
833+ return l. mc (args[ 2 : end ] . .. )
788834end
789835
790836function construct_callable! (l:: LazyCallable{sig} ) where {sig}
@@ -793,3 +839,19 @@ function construct_callable!(l::LazyCallable{sig}) where {sig}
793839 l. position = pos
794840 return nothing
795841end
842+
843+ mutable struct DynamicCallable{V}
844+ cache:: V
845+ end
846+
847+ DynamicCallable () = DynamicCallable (Dict {Any,Any} ())
848+
849+ function (dynamic_callable:: DynamicCallable )(args:: Vararg{Any,N} ) where {N}
850+ sig = Mooncake. _typeof (args)
851+ callable = get (dynamic_callable. cache, sig, nothing )
852+ if callable === nothing
853+ callable = build_callable (sig)
854+ dynamic_callable. cache[sig] = callable
855+ end
856+ return callable[1 ](args[2 : end ]. .. )
857+ end
0 commit comments