Skip to content

Commit 97c78eb

Browse files
committed
Fix calls to Varargs functions
1 parent 42036af commit 97c78eb

File tree

2 files changed

+96
-4
lines changed

2 files changed

+96
-4
lines changed

src/copyable_task.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,31 @@ function callable_ret_type(sig, produce_types)
6363
return Union{Base.code_ircode_by_type(sig)[1][2],produce_type}
6464
end
6565

66+
"""
67+
check_varargs(sig, ir)
68+
69+
For a call signature `sig` and the IR code for it, check whether this is a Varargs call.
70+
71+
The need for this arises because the output of `Base.code_ircode_by_type` does not
72+
distinguish between varargs and tuples, see https://github.com/JuliaLang/julia/issues/58753. Thus we have to go back to the signature that created the IR to check. There are two cases
73+
that signal that this is indeed a varargs call:
74+
1. The last argument in `sig` is a `Vararg` object.
75+
2. The last argument of `ir` is a `Tuple` of the types of the last arguments of `sig`. For
76+
instance `sig` may end in `Symbol, Tuple{Int, Int}` and the last argument of `ir` would be
77+
`Tuple{Symbol, Tuple{Int, Int}}`.
78+
79+
That there are these two cases, and only these two cases, is not based on a good
80+
understanding of anything, but rather on observing which cases arise in our test suite. This solution is thus a hack and should be rewritten by someone who actually understands how IR
81+
handles `Varargs`.
82+
"""
83+
function check_varargs(sig, ir)
84+
sig.parameters[end] isa Core.TypeofVararg && return true
85+
(ir.argtypes[end] isa Type && ir.argtypes[end] <: Tuple) || return false
86+
ir_last_arg_types = ir.argtypes[end].parameters
87+
sig_last_arg_types = sig.parameters[(end - length(ir_last_arg_types) + 1):end]
88+
return sig_last_arg_types == ir_last_arg_types
89+
end
90+
6691
"""
6792
build_callable(sig::Type{<:Tuple})
6893
@@ -84,11 +109,12 @@ function build_callable(sig::Type{<:Tuple})
84109
return fresh_copy(mc_cache[key])
85110
else
86111
ir = Base.code_ircode_by_type(sig)[1][1]
112+
isva = check_varargs(sig, ir)
87113
bb, refs, types = derive_copyable_task_ir(BBCode(ir))
88114
unoptimised_ir = IRCode(bb)
89115
optimised_ir = optimise_ir!(unoptimised_ir)
90116
mc_ret_type = callable_ret_type(sig, types)
91-
mc = misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true)
117+
mc = misty_closure(mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true)
92118
mc_cache[key] = mc
93119
return mc, refs[end]
94120
end
@@ -315,7 +341,7 @@ end
315341
"""
316342
set_taped_globals!(t::TapedTask, new_taped_globals)::Nothing
317343
318-
Set the `taped_globals` of `t` to `new_taped_globals`. Any calls to
344+
Set the `taped_globals` of `t` to `new_taped_globals`. Any calls to
319345
[`get_taped_globals`](@ref) in future calls to `consume(t)` (either directly, or implicitly
320346
via iteration) will see this new value.
321347
"""
@@ -573,7 +599,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
573599
# We enforced above the condition that the final statement in a basic block must not
574600
# produce. This ensures that the final split does not produce. While not strictly
575601
# necessary, this simplifies the implementation (see below).
576-
#
602+
#
577603
# As a result of the above, a basic block will be associated to exactly one split if it
578604
# does not contain any statements which may produce.
579605
#
@@ -595,7 +621,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
595621
# Owing to splitting blocks up, we will need to re-label some `GotoNode`s and
596622
# `GotoIfNot`s. To understand this, consider the following block, whose original `ID`
597623
# we assume to be `ID(old_id)`.
598-
# ID(new_id) - %1 = φ(ID(3) => ...)
624+
# ID(new_id) - %1 = φ(ID(3) => ...)
599625
# ID(new_id) - %3 = call_which_must_not_produce(...)
600626
# ID(new_id) - %4 = produce(%3)
601627
# ID(old_id) - GotoNode(ID(5))

src/test_utils.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,38 @@ function test_cases()
170170
[true, 1],
171171
none,
172172
),
173+
Testcase(
174+
"nested with args (static)",
175+
nothing,
176+
(static_nested_outer_args,),
177+
nothing,
178+
[:a, :b, false],
179+
none,
180+
),
181+
Testcase(
182+
"nested with args (static + used)",
183+
nothing,
184+
(static_nested_outer_use_produced_args,),
185+
nothing,
186+
[:a, :b, 1],
187+
none,
188+
),
189+
Testcase(
190+
"nested with args (dynamic)",
191+
nothing,
192+
(dynamic_nested_outer_args, Ref{Any}(nested_inner_args)),
193+
nothing,
194+
[:a, :b, false],
195+
none,
196+
),
197+
Testcase(
198+
"nested with args (dynamic + used)",
199+
nothing,
200+
(dynamic_nested_outer_use_produced_args, Ref{Any}(nested_inner_args)),
201+
nothing,
202+
[:a, :b, 1],
203+
none,
204+
),
173205
Testcase(
174206
"callable struct", nothing, (CallableStruct(5), 4), nothing, [5, 4, 9], allocs
175207
),
@@ -334,6 +366,40 @@ function dynamic_nested_outer_use_produced(f::Ref{Any})
334366
return nothing
335367
end
336368

369+
@noinline function nested_inner_args(xs...)
370+
for x in xs
371+
produce(x)
372+
end
373+
return 1
374+
end
375+
376+
Libtask.might_produce(::Type{<:Tuple{typeof(nested_inner_args),Any}}) = true
377+
Libtask.might_produce(::Type{<:Tuple{typeof(nested_inner_args),Any,Vararg}}) = true
378+
379+
function static_nested_outer_args()
380+
nested_inner_args(:a, :b)
381+
produce(false)
382+
return nothing
383+
end
384+
385+
function static_nested_outer_use_produced_args()
386+
y = nested_inner_args(:a, :b)
387+
produce(y)
388+
return nothing
389+
end
390+
391+
function dynamic_nested_outer_args(f::Ref{Any})
392+
f[](:a, :b)
393+
produce(false)
394+
return nothing
395+
end
396+
397+
function dynamic_nested_outer_use_produced_args(f::Ref{Any})
398+
y = f[](:a, :b)
399+
produce(y)
400+
return nothing
401+
end
402+
337403
struct CallableStruct{T}
338404
x::T
339405
end

0 commit comments

Comments
 (0)