Skip to content

Commit 79b9848

Browse files
authored
Merge pull request #36 from JuliaDiff/ox/1pb
Remove special handling for pullbacks having multiple inputs
2 parents 1c01d10 + f17c05f commit 79b9848

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.2.7"
3+
version = "0.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.7.1"
14+
ChainRulesCore = "0.8"
1515
Compat = "3"
1616
FiniteDifferences = "0.9"
1717
julia = "1"

src/testers.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
147147
# use collect so can do vector equality
148148
@test isapprox(collect(y_ad), collect(y); rtol=rtol, atol=atol)
149149
@assert !(isa(ȳ, Thunk))
150-
# If the function returned multiple values,
151-
# then it must have multiple seeds for propagating backwards
152-
∂s = (y_ad isa Tuple) ? pullback(ȳ...) : pullback(ȳ)
150+
151+
∂s = pullback(ȳ)
153152
∂self = ∂s[1]
154153
x̄s_ad = ∂s[2:end]
155154
@test ∂self === NO_FIELDS # No internal fields

test/testers.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,28 @@ fbtestkws(x, y; err = true) = err ? error() : x
4949
end
5050
end
5151

52+
@testset "single input, multiple output" begin
53+
simo(x) = (x, 2x)
54+
function ChainRulesCore.rrule(simo, x)
55+
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
56+
return simo(x), simo_pullback
57+
end
58+
function ChainRulesCore.frule((_, ẋ), simo, x)
59+
y = simo(x)
60+
return y, Composite{typeof(y)}(ẋ, 2ẋ)
61+
end
62+
63+
@testset "frule_test" begin
64+
frule_test(simo, (randn(), randn())) # on scalar
65+
frule_test(simo, (randn(4), randn(4))) # on array
66+
end
67+
@testset "rrule_test" begin
68+
# note: we are pulling back tuples (could use Composites here instead)
69+
rrule_test(simo, (randn(), rand()), (randn(), randn())) # on scalar
70+
rrule_test(simo, (randn(4), rand(4)), (randn(4), randn(4))) # on array
71+
end
72+
end
73+
5274

5375
@testset "tuple input: first" begin
5476
ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx))

0 commit comments

Comments
 (0)