@@ -39,9 +39,9 @@ function build_callable(sig::Type{<:Tuple})
3939 return mc, refs[end ]
4040end
4141
42- mutable struct TapedTask{Tdynamic_scope,Targs ,Tmc<: MistyClosure }
42+ mutable struct TapedTask{Tdynamic_scope,Tfargs ,Tmc<: MistyClosure }
4343 dynamic_scope:: Tdynamic_scope
44- args :: Targs
44+ fargs :: Tfargs
4545 const mc:: Tmc
4646 const position:: Base.RefValue{Int32}
4747end
@@ -165,7 +165,7 @@ julia> consume(t)
165165function TapedTask (dynamic_scope:: Any , fargs... )
166166 seed_id! ()
167167 mc, count_ref = build_callable (typeof (fargs))
168- return TapedTask (dynamic_scope, fargs[ 2 : end ] , mc, count_ref)
168+ return TapedTask (dynamic_scope, fargs, mc, count_ref)
169169end
170170
171171"""
@@ -199,7 +199,7 @@ called, it start execution from the entry point. If `consume` has previously bee
199199`nothing` will be returned.
200200"""
201201@inline function consume (t:: TapedTask )
202- v = with (() -> t. mc (t. args ... ), dynamic_scope => t. dynamic_scope)
202+ v = with (() -> t. mc (t. fargs ... ), dynamic_scope => t. dynamic_scope)
203203 return v isa ProducedValue ? v[] : nothing
204204end
205205
@@ -287,12 +287,45 @@ end
287287
288288@inline Base. getindex (x:: ProducedValue ) = x. x
289289
290+ """
291+ inc_args(stmt)
292+
293+ Increment by `1` the `n` field of any `Argument`s present in `stmt`.
294+ Used in `make_ad_stmts!`.
295+ """
296+ inc_args (x:: Expr ) = Expr (x. head, map (__inc, x. args)... )
297+ inc_args (x:: ReturnNode ) = isdefined (x, :val ) ? ReturnNode (__inc (x. val)) : x
298+ inc_args (x:: IDGotoIfNot ) = IDGotoIfNot (__inc (x. cond), x. dest)
299+ inc_args (x:: IDGotoNode ) = x
300+ function inc_args (x:: IDPhiNode )
301+ new_values = Vector {Any} (undef, length (x. values))
302+ for n in eachindex (x. values)
303+ if isassigned (x. values, n)
304+ new_values[n] = __inc (x. values[n])
305+ end
306+ end
307+ return IDPhiNode (x. edges, new_values)
308+ end
309+ inc_args (:: Nothing ) = nothing
310+ inc_args (x:: GlobalRef ) = x
311+ inc_args (x:: Core.PiNode ) = Core. PiNode (__inc (x. val), __inc (x. typ))
312+
313+ __inc (x:: Argument ) = Argument (x. n + 1 )
314+ __inc (x) = x
315+
290316function derive_copyable_task_ir (ir:: BBCode ):: Tuple{BBCode,Tuple}
291317
292318 # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s
293319 # to implement `TapedTask`s, this appears via the first argument.
294320 refs_id = Argument (1 )
295321
322+ # Increment all arguments by 1.
323+ for bb in ir. blocks, (n, inst) in enumerate (bb. insts)
324+ bb. insts[n] = CC. NewInstruction (
325+ inc_args (inst. stmt), inst. type, inst. info, inst. line, inst. flag
326+ )
327+ end
328+
296329 # Construct map between SSA IDs and their index in the state data structure and back.
297330 # Also construct a map from each ref index to its type. We only construct `Ref`s
298331 # for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful
@@ -778,7 +811,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
778811 # rather than nothing at all.
779812 new_argtypes = copy (ir. argtypes)
780813 refs = (_refs... , Ref {Int32} (- 1 ))
781- new_argtypes[ 1 ] = typeof (refs)
814+ new_argtypes = vcat ( typeof (refs), copy (ir . argtypes) )
782815
783816 # Return BBCode and the `Ref`s.
784817 return BBCode (new_bblocks, new_argtypes, ir. sptypes, ir. linetable, ir. meta), refs
830863
831864function (l:: LazyCallable )(args:: Vararg{Any,N} ) where {N}
832865 isdefined (l, :mc ) || construct_callable! (l)
833- return l. mc (args[ 2 : end ] . .. )
866+ return l. mc (args... )
834867end
835868
836869function construct_callable! (l:: LazyCallable{sig} ) where {sig}
@@ -853,5 +886,5 @@ function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N}
853886 callable = build_callable (sig)
854887 dynamic_callable. cache[sig] = callable
855888 end
856- return callable[1 ](args[ 2 : end ] . .. )
889+ return callable[1 ](args... )
857890end
0 commit comments