@@ -241,22 +241,34 @@ function is_produce_stmt(x)::Bool
241
241
end
242
242
243
243
"""
244
- stmt_might_produce(x)::Bool
244
+ stmt_might_produce(x, ret_type::Type )::Bool
245
245
246
246
`true` if `x` might contain a call to `produce`, and `false` otherwise.
247
247
"""
248
- function stmt_might_produce (x):: Bool
248
+ function stmt_might_produce (x, ret_type:: Type ):: Bool
249
+
250
+ # Statement will terminate in an unusual fashion, so don't bother recursing.
251
+ # This isn't _strictly_ correct (there could be a `produce` statement before the
252
+ # `throw` call is hit), but this seems unlikely to happen in practice. If it does, the
253
+ # user should get a sensible error message anyway.
254
+ ret_type == Union{} && return false
255
+
256
+ # Statement will terminate in the usual fashion, so _do_ bother recusing.
249
257
is_produce_stmt (x) && return true
250
258
Meta. isexpr (x, :invoke ) && return might_produce (x. args[1 ]. specTypes)
259
+ if Meta. isexpr (x, :call )
260
+ # This is a hack -- it's perfectly possible for `DataType` calls to produce in general.
261
+ f = get_function (x. args[1 ])
262
+ _might_produce = ! isa (f, Union{Core. IntrinsicFunction,Core. Builtin,DataType})
263
+ return _might_produce
264
+ end
251
265
return false
252
-
253
- # # TODO : make this correct
254
- # Meta.isexpr(x, :call) &&
255
- # return !isa(x.args[1], Union{Core.IntrinsicFunction,Core.Builtin})
256
- # Meta.isexpr(x, :invoke) && return false # todo: make this more accurate
257
- # return false
258
266
end
259
267
268
+ get_function (x) = x
269
+ get_function (x:: Expr ) = eval (x)
270
+ get_function (x:: GlobalRef ) = isconst (x) ? getglobal (x. mod, x. name) : x. binding
271
+
260
272
"""
261
273
produce_value(x::Expr)
262
274
@@ -347,7 +359,11 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
347
359
# in the block at which the corresponding split starts and finishes.
348
360
all_splits = map (ir. blocks) do block
349
361
split_ends = vcat (
350
- findall (inst -> stmt_might_produce (inst. stmt), block. insts), length (block)
362
+ findall (
363
+ inst -> stmt_might_produce (inst. stmt, CC. widenconst (inst. type)),
364
+ block. insts,
365
+ ),
366
+ length (block),
351
367
)
352
368
return map (enumerate (split_ends)) do (n, split_end)
353
369
return (start= (n == 1 ? 0 : split_ends[n - 1 ]) + 1 , last= split_end)
@@ -654,48 +670,63 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
654
670
push! (new_blocks, BBlock (splits_ids[n], inst_pairs))
655
671
656
672
# Derive TapedTask for this statement.
657
- callable = if Meta. isexpr (stmt, :invoke )
673
+ ( callable, callable_args) = if Meta. isexpr (stmt, :invoke )
658
674
sig = stmt. args[1 ]. specTypes
659
- LazyCallable {sig,callable_ret_type(sig)} ()
675
+ (LazyCallable {sig,callable_ret_type(sig)} (), stmt. args[2 : end ])
676
+ elseif Meta. isexpr (stmt, :call )
677
+ (DynamicCallable (), stmt. args)
660
678
else
679
+ display (stmt)
680
+ println ()
661
681
error (" unhandled statement which might produce $stmt " )
662
682
end
663
683
684
+ # Find any `ID`s and replace them with calls to read whatever is stored
685
+ # in the `Ref`s that they are associated to.
686
+ callable_inst_pairs = Mooncake. IDInstPair[]
687
+ for (n, arg) in enumerate (callable_args)
688
+ arg isa ID || continue
689
+
690
+ new_id = ID ()
691
+ ref_ind = ssa_id_to_ref_index_map[arg]
692
+ expr = Expr (:call , get_ref_at, refs_id, ref_ind)
693
+ push! (callable_inst_pairs, (new_id, new_inst (expr)))
694
+ callable_args[n] = new_id
695
+ end
696
+
664
697
# Allocate a slot in the _refs vector for this callable.
665
698
push! (_refs, Ref (callable))
666
699
callable_ind = length (_refs)
667
700
668
701
# Retrieve the callable from the refs.
669
702
callable_id = ID ()
670
- callable = Expr (:call , get_ref_at, refs_id, callable_ind)
703
+ callable_stmt = Expr (:call , get_ref_at, refs_id, callable_ind)
704
+ push! (callable_inst_pairs, (callable_id, new_inst (callable_stmt)))
671
705
672
706
# Call the callable.
673
- result = Expr (:call , callable_id, stmt. args[3 : end ]. .. )
674
707
result_id = ID ()
708
+ result_stmt = Expr (:call , callable_id, callable_args... )
709
+ push! (callable_inst_pairs, (result_id, new_inst (result_stmt)))
675
710
676
711
# Determine whether this TapedTask has produced a not-a-`ProducedValue`.
677
- not_produced = Expr (:call , not_a_produced, result_id)
678
712
not_produced_id = ID ()
713
+ not_produced_stmt = Expr (:call , not_a_produced, result_id)
714
+ push! (callable_inst_pairs, (not_produced_id, new_inst (not_produced_stmt)))
679
715
680
716
# Go to a block which just returns the `ProducedValue`, if a
681
717
# `ProducedValue` is returned, otherwise continue to the next split.
682
718
is_produced_block_id = ID ()
683
- next_block_id = splits_ids[n + 1 ] # safe since the last split ends with a terminator
684
- # switch = Switch(Any[not_produced_id], [is_produced_block_id], next_block_id)
685
- switch = IDGotoIfNot (not_produced_id, is_produced_block_id)
686
-
687
- # Insert a new block to hold the three previous statements.
688
- callable_inst_pairs = Mooncake. IDInstPair[
689
- (callable_id, new_inst (callable)),
690
- (result_id, new_inst (result)),
691
- (not_produced_id, new_inst (not_produced)),
692
- (ID (), new_inst (switch)),
693
- ]
719
+ is_not_produced_block_id = ID ()
720
+ switch = Switch (
721
+ Any[not_produced_id],
722
+ [is_produced_block_id],
723
+ is_not_produced_block_id,
724
+ )
725
+ push! (callable_inst_pairs, (ID (), new_inst (switch)))
726
+
727
+ # Push the above statements onto a new block.
694
728
push! (new_blocks, BBlock (callable_block_id, callable_inst_pairs))
695
729
696
- goto_block = BBlock (ID (), [(ID (), new_inst (IDGotoNode (next_block_id)))])
697
- push! (new_blocks, goto_block)
698
-
699
730
# Construct block which handles the case that we got a `ProducedValue`. If
700
731
# this happens, it means that `callable` has more things to produce still.
701
732
# This means that we need to call it again next time we enter this function.
@@ -709,6 +740,21 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
709
740
(return_id, new_inst (ReturnNode (result_id))),
710
741
]
711
742
push! (new_blocks, BBlock (is_produced_block_id, produced_block_inst_pairs))
743
+
744
+ # Construct block which handles the case that we did not get a
745
+ # `ProducedValue`. In this case, we must first push the result to the `Ref`
746
+ # associated to the call, and goto the next split.
747
+ next_block_id = splits_ids[n + 1 ] # safe since the last split ends with a terminator
748
+ result_ref_ind = ssa_id_to_ref_index_map[id]
749
+ set_ref = Expr (:call , set_ref_at!, refs_id, result_ref_ind, result_id)
750
+ not_produced_block_inst_pairs = Mooncake. IDInstPair[
751
+ (ID (), new_inst (set_ref))
752
+ (ID (), new_inst (IDGotoNode (next_block_id)))
753
+ ]
754
+ push! (
755
+ new_blocks,
756
+ BBlock (is_not_produced_block_id, not_produced_block_inst_pairs),
757
+ )
712
758
end
713
759
return new_blocks
714
760
end
784
830
785
831
function (l:: LazyCallable )(args:: Vararg{Any,N} ) where {N}
786
832
isdefined (l, :mc ) || construct_callable! (l)
787
- return l. mc (args... )
833
+ return l. mc (args[ 2 : end ] . .. )
788
834
end
789
835
790
836
function construct_callable! (l:: LazyCallable{sig} ) where {sig}
@@ -793,3 +839,19 @@ function construct_callable!(l::LazyCallable{sig}) where {sig}
793
839
l. position = pos
794
840
return nothing
795
841
end
842
+
843
+ mutable struct DynamicCallable{V}
844
+ cache:: V
845
+ end
846
+
847
+ DynamicCallable () = DynamicCallable (Dict {Any,Any} ())
848
+
849
+ function (dynamic_callable:: DynamicCallable )(args:: Vararg{Any,N} ) where {N}
850
+ sig = Mooncake. _typeof (args)
851
+ callable = get (dynamic_callable. cache, sig, nothing )
852
+ if callable === nothing
853
+ callable = build_callable (sig)
854
+ dynamic_callable. cache[sig] = callable
855
+ end
856
+ return callable[1 ](args[2 : end ]. .. )
857
+ end
0 commit comments