diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index d3ba9e0..67213b1 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -33,9 +33,11 @@ TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing) State wrapper to hold `Libtask.CTask` model initiated from `f`. """ function AdvancedPS.LibtaskModel( - f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... + f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG ) # Changed the API, need to take care of the RNG properly - return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...)) + return AdvancedPS.LibtaskModel( + f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...) + ) end """ @@ -114,7 +116,9 @@ function AdvancedPS.forkr(trace::LibtaskTrace) newf = AdvancedPS.reset_model(trace.model.f) Random123.set_counter!(rng, 1) - ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) + ctask = Libtask.TapedTask( + TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs... + ) new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) # add backward reference