Skip to content

Commit 51bd9b5

Browse files
Fix vararg (#158)
* Fix vararg * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 626cb72 commit 51bd9b5

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/utils.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,24 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
8787
end
8888

8989
# TODO replace with `Base.invoke_within` if julia#52964 lands
90-
ir = first(only(
90+
ir, ty = only(
9191
# TODO fix it for kwargs
9292
Base.code_ircode(f, map(typeof, traced_args); interp),
93-
))
94-
93+
)
9594
oc = Core.OpaqueClosure(ir)
9695

9796
if f === Reactant.apply
9897
oc(traced_args[1], (traced_args[2:end]...,))
9998
else
100-
oc(traced_args...)
99+
if length(traced_args) + 1 != length(ir.argtypes)
100+
@assert ir.argtypes[end] <: Tuple
101+
oc(
102+
traced_args[1:(length(ir.argtypes) - 2)]...,
103+
(traced_args[(length(ir.argtypes) - 1):end]...,),
104+
)
105+
else
106+
oc(traced_args...)
107+
end
101108
end
102109
end
103110

test/basic.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,17 @@ bcast_cos(x) = cos.(x)
105105
@test r cos.(x)
106106
end
107107

108+
f_var(args...) = sum(args)
109+
110+
@testset "Vararg" begin
111+
x = Reactant.to_rarray(ones(3))
112+
y = Reactant.to_rarray(3 * ones(3))
113+
z = Reactant.to_rarray(2.6 * ones(3))
114+
115+
f2 = @compile f_var(x, y, z)
116+
@test f2(x, y, z) [6.6, 6.6, 6.6]
117+
end
118+
108119
function sumcos(x)
109120
return sum(cos.(x))
110121
end

0 commit comments

Comments
 (0)