Skip to content

Commit 843bc83

Browse files
committed
Add @might_produce_kwargs macro
1 parent f154425 commit 843bc83

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.9.4"
6+
version = "0.9.5"
77

88
[deps]
99
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"

src/copyable_task.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,58 @@ end
354354
`true` if a call to method with signature `sig` is permitted to contain
355355
`Libtask.produce` statements.
356356
357-
This is an opt-in mechanism. the fallback method of this function returns `false` indicating
357+
This is an opt-in mechanism. The fallback method of this function returns `false` indicating
358358
that, by default, we assume that calls do not contain `Libtask.produce` statements.
359359
"""
360360
might_produce(::Type{<:Tuple}) = false
361361

362+
"""
363+
@might_produce_kwargs(f)
364+
365+
If `f` is a function that has keyword arguments and may call `Libtask.produce` inside it,
366+
then `@might_produce_kwargs(f)` will generate the appropriate methods needed to ensure that
367+
`Libtask.might_produce` returns `true` for the relevant signatures of `f`.
368+
369+
```jldoctest kwargs
370+
julia> # For this demonstration we need to mark `g` as not being inlineable.
371+
@noinline function g(x; y, z=0)
372+
produce(x + y + z)
373+
end
374+
g (generic function with 1 method)
375+
376+
julia> function f()
377+
g(1; y=2, z=3)
378+
end
379+
f (generic function with 1 method)
380+
381+
julia> # This returns nothing because `g` isn't yet marked as being able to `produce`.
382+
consume(Libtask.TapedTask(nothing, f))
383+
384+
julia> Libtask.@might_produce_kwargs(g)
385+
386+
julia> # Now it works!
387+
consume(Libtask.TapedTask(nothing, f))
388+
6
389+
"""
390+
macro might_produce_kwargs(f)
391+
# See https://github.com/TuringLang/Libtask.jl/issues/197 for discussion of this macro.
392+
quote
393+
possible_n_kwargs = unique(map(length Base.kwarg_decl, methods($(esc(f)))))
394+
if possible_n_kwargs != [0]
395+
# Oddly we need to interpolate the module and not the function: either
396+
# `$(might_produce)` or $(Libtask.might_produce) seem more natural but both of
397+
# those cause the entire `Libtask.might_produce` to be treated as a single
398+
# symbol. See https://discourse.julialang.org/t/128613
399+
$(Libtask).might_produce(::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof($(esc(f))),Vararg}}) = true
400+
for n in possible_n_kwargs
401+
# We only need `Any` and not `<:Any` because tuples are covariant.
402+
kwarg_types = fill(Any, n)
403+
$(Libtask).might_produce(::Type{<:Tuple{<:Function,kwarg_types...,typeof($(esc(f))),Vararg}}) = true
404+
end
405+
end
406+
end
407+
end
408+
362409
# Helper struct used in `derive_copyable_task_ir`.
363410
struct TupleRef
364411
n::Int

0 commit comments

Comments
 (0)