@@ -35,7 +35,32 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
3535function AdvancedPS. LibtaskModel(
3636 f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
3737) # Changed the API, need to take care of the RNG properly
38- return AdvancedPS. LibtaskModel(f, Libtask. TapedTask(TapedGlobals(rng), f, args... ))
38+ return AdvancedPS. LibtaskModel(
39+ f, Libtask. TapedTask(TapedGlobals(rng), f, args... )
40+ )
41+ end
42+ # TODO : Upstream this to Turing
43+ function AdvancedPS. LibtaskModel(
44+ f:: AdvancedPS.AbstractTuringLibtaskModel , rng:: Random.AbstractRNG
45+ )
46+ return AdvancedPS. LibtaskModel(
47+ f, Libtask. TapedTask(TapedGlobals(rng), f. fargs... ; f. kwargs... )
48+ )
49+ end
50+
51+ const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
52+
53+ function to_tapedtask(
54+ newf:: AdvancedPS.AbstractGenericModel , trace:: LibtaskTrace , rng:: Random.AbstractRNG
55+ )
56+ return Libtask. TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
57+ end
58+ function to_tapedtask(
59+ newf:: AdvancedPS.AbstractTuringLibtaskModel , trace:: LibtaskTrace , rng:: Random.AbstractRNG
60+ )
61+ return Libtask. TapedTask(
62+ TapedGlobals(rng, get_other_global(trace)), newf. fargs... ; newf. kwargs...
63+ )
3964end
4065
4166"""
@@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel)
4772 return AdvancedPS. LibtaskModel(deepcopy(model. f), copy(model. ctask))
4873end
4974
50- const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
51-
5275function Base. copy(trace:: LibtaskTrace )
5376 newtrace = AdvancedPS. Trace(copy(trace. model), deepcopy(trace. rng))
5477 set_other_global!(newtrace, newtrace)
@@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114137 newf = AdvancedPS. reset_model(trace. model. f)
115138 Random123. set_counter!(rng, 1 )
116139
117- ctask = Libtask . TapedTask(TapedGlobals(rng, get_other_global( trace)), newf )
140+ ctask = to_tapedtask(newf, trace, rng )
118141 new_tapedmodel = AdvancedPS. LibtaskModel(newf, ctask)
119142
120143 # add backward reference
0 commit comments