From 1d6f9a1ed4f6eceb20eea01acbec72ef038634c5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 8 Jul 2025 16:08:35 +0100 Subject: [PATCH 1/3] Fix a bug with using the return value of produce --- src/copyable_task.jl | 9 +++++++++ test/copyable_task.jl | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 4ba0f36..3dba030 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -898,6 +898,15 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} prod_val = deref_id end + # Set the ref for this statement, as we would for any other call or invoke. + # The TapedTask may need to read this ref when it resumes, if the return + # value of `produce` is used within the original function. + if is_used_dict[id] + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, prod_val) + push!(inst_pairs, (ID(), new_inst(set_ref))) + end + # Construct a `ProducedValue`. val_id = ID() push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val)))) diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 6837e88..67205e0 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -220,4 +220,16 @@ g() = produce(rand() > -1.0 ? 2 : 0.1) @test Libtask.consume(Libtask.TapedTask(nothing, g)) == 2 end + + @testset "Return produce" begin + # Test calling a function that does something with the return value of `produce`. + # In this case it just returns it. This used to error, see + # https://github.com/TuringLang/Libtask.jl/issues/190. + produce_wrapper(x) = Libtask.produce(x) + Libtask.might_produce(::Type{<:Tuple{typeof(produce_wrapper),Any}}) = true + f(obs) = produce_wrapper(obs) + tt = Libtask.TapedTask(nothing, f, :a) + @test Libtask.consume(tt) === :a + @test Libtask.consume(tt) === nothing + end end From 1909958207436fe2dabc50d28ff4eb19b80b83e5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 8 Jul 2025 16:09:09 +0100 Subject: [PATCH 2/3] Bump patch version to 0.9.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3a54c66..49271c3 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.2" +version = "0.9.3" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" From ada7e0c3a3c4614772de7ef943ef28cb439526fc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 8 Jul 2025 16:13:16 +0100 Subject: [PATCH 3/3] Simplify test --- test/copyable_task.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 67205e0..def48b5 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -225,9 +225,7 @@ # Test calling a function that does something with the return value of `produce`. # In this case it just returns it. This used to error, see # https://github.com/TuringLang/Libtask.jl/issues/190. - produce_wrapper(x) = Libtask.produce(x) - Libtask.might_produce(::Type{<:Tuple{typeof(produce_wrapper),Any}}) = true - f(obs) = produce_wrapper(obs) + f(obs) = produce(obs) tt = Libtask.TapedTask(nothing, f, :a) @test Libtask.consume(tt) === :a @test Libtask.consume(tt) === nothing