Skip to content

Commit 1d6f9a1

Browse files
committed
Fix a bug with using the return value of produce
1 parent baf1c50 commit 1d6f9a1

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/copyable_task.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,15 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
898898
prod_val = deref_id
899899
end
900900

901+
# Set the ref for this statement, as we would for any other call or invoke.
902+
# The TapedTask may need to read this ref when it resumes, if the return
903+
# value of `produce` is used within the original function.
904+
if is_used_dict[id]
905+
out_ind = ssa_id_to_ref_index_map[id]
906+
set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, prod_val)
907+
push!(inst_pairs, (ID(), new_inst(set_ref)))
908+
end
909+
901910
# Construct a `ProducedValue`.
902911
val_id = ID()
903912
push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val))))

test/copyable_task.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,16 @@
220220
g() = produce(rand() > -1.0 ? 2 : 0.1)
221221
@test Libtask.consume(Libtask.TapedTask(nothing, g)) == 2
222222
end
223+
224+
@testset "Return produce" begin
225+
# Test calling a function that does something with the return value of `produce`.
226+
# In this case it just returns it. This used to error, see
227+
# https://github.com/TuringLang/Libtask.jl/issues/190.
228+
produce_wrapper(x) = Libtask.produce(x)
229+
Libtask.might_produce(::Type{<:Tuple{typeof(produce_wrapper),Any}}) = true
230+
f(obs) = produce_wrapper(obs)
231+
tt = Libtask.TapedTask(nothing, f, :a)
232+
@test Libtask.consume(tt) === :a
233+
@test Libtask.consume(tt) === nothing
234+
end
223235
end

0 commit comments

Comments
 (0)