diff --git a/Project.toml b/Project.toml index 7920b392..3a54c667 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.1" +version = "0.9.2" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 2334bb33..2210c5e1 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -72,6 +72,13 @@ the current world age, will make a copy of an existing `MistyClosure`. If not, will derive it from scratch (derive the IR + compile it etc). """ function build_callable(sig::Type{<:Tuple}) + if sig <: Tuple{typeof(produce),Any} + msg = """ + Can not construct a TapedTask for a 'naked' call to `produce`. + Please wrap the call to `produce` in a function, and construct a + TapedTask from that function.""" + throw(ArgumentError(msg)) + end key = CacheKey(Base.get_world_counter(), sig) if haskey(mc_cache, key) return fresh_copy(mc_cache[key]) @@ -367,8 +374,8 @@ get_value(x) = x expression, otherwise `false`. """ function is_produce_stmt(x)::Bool - if Meta.isexpr(x, :invoke) && length(x.args) == 3 - return get_value(x.args[2]) === produce + if Meta.isexpr(x, :invoke) && length(x.args) == 3 && x.args[1] isa Core.MethodInstance + return x.args[1].specTypes <: Tuple{typeof(produce),Any} elseif Meta.isexpr(x, :call) && length(x.args) == 2 return get_value(x.args[1]) === produce else diff --git a/test/copyable_task.jl b/test/copyable_task.jl index f6562993..6837e881 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -138,6 +138,12 @@ @test ex isa BoundsError end end + + @testset "Naked produce" begin + @test_throws "wrap the call to `produce` in a function" Libtask.consume( + Libtask.TapedTask(nothing, Libtask.produce, 0) + ) + end end @testset "copying" begin @@ -209,4 +215,9 @@ end @test ex === nothing end + + @testset "Issue #185" begin + g() = produce(rand() > -1.0 ? 2 : 0.1) + @test Libtask.consume(Libtask.TapedTask(nothing, g)) == 2 + end end