Skip to content

Commit 0af8299

Browse files
Kenooxinabox
andauthored
Fix codegen of fwd varargs (#137)
* Fix Composites of Structs * Remove misleading comment * Fix varargs codegen Fixes the varargs issue seen in #134, but doesn't quite make nesting work yet. * Marked fixed tests as passing * Fix test --------- Co-authored-by: Frames White <[email protected]>
1 parent a8525ae commit 0af8299

File tree

5 files changed

+40
-7
lines changed

5 files changed

+40
-7
lines changed

src/codegen/forward.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ function fwd_transform!(ci, mi, nargs, N)
6363
end
6464

6565
meth = mi.def::Method
66-
nargs = Int(meth.nargs)
67-
for i = 1:nargs
68-
if meth.isva && i == nargs
66+
for i = 1:meth.nargs
67+
if meth.isva && i == meth.nargs
6968
args = map(i:(nargs+1)) do j::Int
7069
emit!(Expr(:call, getfield, SlotNumber(2), j))
7170
end

src/stage1/forward.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
44
partial(x::UniformTangent, i) = getfield(x, :val)
55
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
66
partial(x::AbstractZero, i) = x
7-
partial(x::CompositeBundle{N, B}, i) where {N, B} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...)
7+
partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...)
8+
function partial(x::CompositeBundle{N, B}, i) where {N, B}
9+
# This is tangent for a struct, but fields partials are each stored in a plain tuple
10+
# so we add the names back using the primal `B`
11+
# TODO: If required this can be done as a `@generated` function so it is type-stable
12+
backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup)))
13+
return Tangent{B, typeof(backing)}(backing)
14+
end
15+
16+
817
primal(x::AbstractTangentBundle) = x.primal
918
primal(z::ZeroTangent) = ZeroTangent()
1019

src/stage1/recurse_fwd.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
3535
Core.svec(:ff, :args), Core.svec(), :(∂☆passthrough(args)))
3636
end
3737

38-
# Check if we have an rrule for this function
3938
sig = Tuple{map(π, args)...}
4039
mthds = Base._methods_by_ftype(sig, -1, world)
4140
if mthds === nothing || length(mthds) != 1

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function sin_twice_fwd(x)
154154
end
155155
end
156156
let var"'" = Diffractor.PrimeDerivativeFwd
157-
@test_broken sin_twice_fwd'(1.0) == sin'''(1.0)
157+
@test sin_twice_fwd'(1.0) == sin'''(1.0)
158158
end
159159

160160
# Regression tests
@@ -228,7 +228,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
228228
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] [0.2338, -0.0177, -0.0661] atol=1e-3
229229
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
230230

231-
@test_broken gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
231+
@test gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
232232
exp_log(x) = exp(log(x))
233233
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
234234
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])

test/stage2_fwd.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module stage2_fwd
22
using Diffractor, Test, ChainRulesCore
3+
34
mysin(x) = sin(x)
45
let sin′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64})
56
@test sin′(1.0) == cos(1.0)
@@ -17,4 +18,29 @@ module stage2_fwd
1718
@test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
1819
@test self_minus′′(1.0) == 0.
1920
end
21+
22+
@testset "structs" begin
23+
struct Foo
24+
x
25+
y
26+
end
27+
foo_dub(x) = Foo(x, 2x)
28+
dz = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(foo_dub), Diffractor.TaylorBundle{1}(10.0, (π,)))
29+
@test Diffractor.first_partial(dz) == Tangent{Foo}(;x=π, y=2π)
30+
end
31+
32+
@testset "mix of vararg and positional args" begin
33+
cc(a, x::Vararg) = nothing
34+
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(cc), Diffractor.TaylorBundle{1}(10f0, (10.0,)), Diffractor.TaylorBundle{1}(10f0, (10.0,)))
35+
36+
gg(a, xs...) = nothing
37+
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(gg), Diffractor.TaylorBundle{1}(10f0, (1.2,)), Diffractor.TaylorBundle{1}(20f0, (1.1,)))
38+
end
39+
40+
41+
@testset "nontrivial nested" begin
42+
f(x) = 3x^2
43+
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
44+
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
45+
end
2046
end

0 commit comments

Comments
 (0)