diff --git a/Project.toml b/Project.toml index 3a54c66..49271c3 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.2" +version = "0.9.3" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 4ba0f36..3dba030 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -898,6 +898,15 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} prod_val = deref_id end + # Set the ref for this statement, as we would for any other call or invoke. + # The TapedTask may need to read this ref when it resumes, if the return + # value of `produce` is used within the original function. + if is_used_dict[id] + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, prod_val) + push!(inst_pairs, (ID(), new_inst(set_ref))) + end + # Construct a `ProducedValue`. val_id = ID() push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val)))) diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 6837e88..def48b5 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -220,4 +220,14 @@ g() = produce(rand() > -1.0 ? 2 : 0.1) @test Libtask.consume(Libtask.TapedTask(nothing, g)) == 2 end + + @testset "Return produce" begin + # Test calling a function that does something with the return value of `produce`. + # In this case it just returns it. This used to error, see + # https://github.com/TuringLang/Libtask.jl/issues/190. + f(obs) = produce(obs) + tt = Libtask.TapedTask(nothing, f, :a) + @test Libtask.consume(tt) === :a + @test Libtask.consume(tt) === nothing + end end