diff --git a/Project.toml b/Project.toml index 4a7eed2b..744e97cd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.4" +version = "0.9.5" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" diff --git a/docs/src/index.md b/docs/src/index.md index 1f2b2b57..5f8edada 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -28,4 +28,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme ```@docs; canonical=true Libtask.might_produce(::Type{<:Tuple}) +Libtask.@might_produce ``` diff --git a/perf/p0.jl b/perf/p0.jl index 23a71250..c317b885 100644 --- a/perf/p0.jl +++ b/perf/p0.jl @@ -14,7 +14,7 @@ end # Case 1: Sample from the prior. rng = MersenneTwister() -m = Turing.Core.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng) +m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng) f = m.evaluator[1]; args = m.evaluator[2:end]; @@ -27,7 +27,7 @@ println("Run a tape...") @btime t.tf(args...) # Case 2: SMC sampler -m = Turing.Core.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng) +m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng) f = m.evaluator[1]; args = m.evaluator[2:end]; diff --git a/perf/p2.jl b/perf/p2.jl index bd904c33..6a883411 100644 --- a/perf/p2.jl +++ b/perf/p2.jl @@ -52,7 +52,7 @@ Random.seed!(rng, 2) iterations = 500 model_fun = infiniteGMM(data) -m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng) +m = Turing.Inference.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng) f = m.evaluator[1] args = m.evaluator[2:end] diff --git a/src/copyable_task.jl b/src/copyable_task.jl index ac549015..fe8b549d 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -354,11 +354,70 @@ end `true` if a call to method with signature `sig` is permitted to contain `Libtask.produce` statements. -This is an opt-in mechanism. the fallback method of this function returns `false` indicating +This is an opt-in mechanism. The fallback method of this function returns `false` indicating that, by default, we assume that calls do not contain `Libtask.produce` statements. """ might_produce(::Type{<:Tuple}) = false +""" + @might_produce(f) + +If `f` is a function that may call `Libtask.produce` inside it, then `@might_produce(f)` +will generate the appropriate methods needed to ensure that `Libtask.might_produce` returns +`true` for all relevant signatures of `f`. This works even if `f` has methods with keyword +arguments. + +```jldoctest might_produce_macro +julia> # For this demonstration we need to mark `g` as not being inlineable. + @noinline function g(x; y, z=0) + produce(x + y + z) + end +g (generic function with 1 method) + +julia> function f() + g(1; y=2, z=3) + end +f (generic function with 1 method) + +julia> # This returns nothing because `g` isn't yet marked as being able to `produce`. + consume(Libtask.TapedTask(nothing, f)) + +julia> Libtask.@might_produce(g) + +julia> # Now it works! + consume(Libtask.TapedTask(nothing, f)) +6 +""" +macro might_produce(f) + # See https://github.com/TuringLang/Libtask.jl/issues/197 for discussion of this macro. + quote + function $(Libtask).might_produce(::Type{<:Tuple{typeof($(esc(f))),Vararg}}) + return true + end + possible_n_kwargs = unique(map(length ∘ Base.kwarg_decl, methods($(esc(f))))) + if possible_n_kwargs != [0] + # Oddly we need to interpolate the module and not the function: either + # `$(might_produce)` or $(Libtask.might_produce) seem more natural but both of + # those cause the entire `Libtask.might_produce` to be treated as a single + # symbol. See https://discourse.julialang.org/t/128613 + function $(Libtask).might_produce( + ::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof($(esc(f))),Vararg}} + ) + return true + end + for n in possible_n_kwargs + # We only need `Any` and not `<:Any` because tuples are covariant. + kwarg_types = fill(Any, n) + function $(Libtask).might_produce( + ::Type{<:Tuple{<:Function,kwarg_types...,typeof($(esc(f))),Vararg}} + ) + return true + end + end + end + end +end + # Helper struct used in `derive_copyable_task_ir`. struct TupleRef n::Int diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 1b051daa..9ae5377e 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -251,4 +251,53 @@ @test Libtask.consume(tt) === :a @test Libtask.consume(tt) === nothing end + + @testset "@might_produce macro" begin + # Positional arguments only + @noinline g1(x) = produce(x) + f1(x) = g1(x) + # Without marking it as might_produce + tt = Libtask.TapedTask(nothing, f1, 0) + @test Libtask.consume(tt) === nothing + # Now marking it + Libtask.@might_produce(g1) + tt = Libtask.TapedTask(nothing, f1, 0) + @test Libtask.consume(tt) === 0 + @test Libtask.consume(tt) === nothing + + # Keyword arguments only + @noinline g2(x; y=1, z=2) = produce(x + y + z) + f2(x) = g2(x) + # Without marking it as might_produce + tt = Libtask.TapedTask(nothing, f2, 0) + @test Libtask.consume(tt) === nothing + # Now marking it + Libtask.@might_produce(g2) + tt = Libtask.TapedTask(nothing, f2, 0) + @test Libtask.consume(tt) === 3 + @test Libtask.consume(tt) === nothing + + # A function with multiple methods. + # The function reference is used to ensure that it really doesn't get inlined + # (otherwise, for reasons that are yet unknown, these functions do get inlined when + # inside a testset) + @noinline g3(x) = produce(x) + @noinline g3(x, y; z) = produce(x + y + z) + @noinline g3(x, y, z; p, q) = produce(x + y + z + p + q) + function f3(x, fref) + fref[](x) + fref[](x, 1; z=2) + fref[](x, 1, 2; p=3, q=4) + return nothing + end + tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3)) + @test Libtask.consume(tt) === nothing + # Now marking it + Libtask.@might_produce(g3) + tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3)) + @test Libtask.consume(tt) === 0 + @test Libtask.consume(tt) === 3 + @test Libtask.consume(tt) === 10 + @test Libtask.consume(tt) === nothing + end end