Skip to content

Commit 3c18e86

Browse files
devmotionsethaxengdalle
authored
Update of #36 and #35 (#93)
* Define value_and_pullback_function as returning value * Update definition of Jacobian * Update tests * Update API definition * Increment minor version number * Make value_and_pullback_function the primitive * Define pullback_function in terms of value_and_pullback_function * Define value_and_pullback_function in tests * Increment version number * Use value_and_pb_function in CRC and Tracker extensions * Fix bug in Tracker backend * More updates * Fix end * Handle `nothing` * Fix test failures * Use named functions --------- Co-authored-by: Seth Axen <[email protected]> Co-authored-by: Guillaume Dalle <[email protected]>
1 parent 00181f8 commit 3c18e86

9 files changed

+92
-73
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractDifferentiation"
22
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
33
authors = ["Mohamed Tarek <[email protected]> and contributors"]
4-
version = "0.5.3"
4+
version = "0.6.0-DEV"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ This operation goes by a few names. Refer to the [ChainRules documentation](http
9292

9393
The following functions can be used to request the pullback operator/function with or without the function value. In order to request the pullback function `pb_f` of a function `f` at the inputs `xs`, you can use either of:
9494
- `pb_f = AD.pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns the pullback function `pb_f` of the function `f` at the inputs `xs`. `pb_f` is a function that accepts the co-tangents `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can also accept a single input instead of a 1-tuple.
95-
- `value_and_pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns a function `value_and_pb_f` which accepts the co-tangent `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `value_and_pb_f` can accept a single input instead of a 1-tuple. `value_and_pb_f` returns a 2-tuple, namely the value `f(xs...)` and output of the pullback operator.
95+
- `v, pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: computes the function value `v = f(xs...)` and returns a 2-tuple containing the value `v` and a function `pb_f` that accepts the co-tangents `vs` as input, which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can accept a single input instead of a 1-tuple.
9696

9797
### Lazy operators
9898

ext/AbstractDifferentiationChainRulesCoreExt.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@ module AbstractDifferentiationChainRulesCoreExt
33
import AbstractDifferentiation as AD
44
using ChainRulesCore: ChainRulesCore
55

6-
AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
7-
_, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
8-
pullback(vs) = Base.tail(back(vs))
9-
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
10-
return pullback
6+
AD.@primitive function value_and_pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
7+
value, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
8+
function rrule_pullback(vs)
9+
_vs = if vs isa Tuple && !(value isa Tuple)
10+
only(vs)
11+
else
12+
vs
13+
end
14+
return Base.tail(back(_vs))
15+
end
16+
return value, rrule_pullback
1117
end
1218

1319
end # module

ext/AbstractDifferentiationFiniteDifferencesExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ function AD.jacobian(ba::AD.FiniteDifferencesBackend, f, xs...)
1919
return FiniteDifferences.jacobian(ba.method, f, xs...)
2020
end
2121

22+
function AD.gradient(ba::AD.FiniteDifferencesBackend, f, xs...)
23+
return FiniteDifferences.grad(ba.method, f, xs...)
24+
end
25+
2226
function AD.pushforward_function(ba::AD.FiniteDifferencesBackend, f, xs...)
2327
return function pushforward(vs)
2428
ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...)
@@ -32,6 +36,15 @@ function AD.pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
3236
end
3337
end
3438

39+
# Ensure consistency with `value_and_pullback` function
40+
function AD.value_and_pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
41+
value = f(xs...)
42+
function fd_pullback(vs)
43+
return FiniteDifferences.j′vp(ba.method, f, vs, xs...)
44+
end
45+
return value, fd_pullback
46+
end
47+
3548
# Better performance: issue #87
3649
function AD.derivative(ba::AD.FiniteDifferencesBackend, f::TF, x::Real) where {TF<:Function}
3750
return (ba.method(f, x),)

ext/AbstractDifferentiationTrackerExt.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
1616
AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
1717
AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)
1818

19-
AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...)
20-
value, back = Tracker.forward(f, xs...)
21-
function pullback(ws)
22-
if ws isa Tuple && !(value isa Tuple)
23-
map(Tracker.data, back(only(ws)))
19+
AD.@primitive function value_and_pullback_function(ba::AD.TrackerBackend, f, xs...)
20+
_value, back = Tracker.forward(f, xs...)
21+
value = map(Tracker.data, _value)
22+
function tracker_pullback(ws)
23+
_ws = if ws isa Tuple && !(value isa Tuple)
24+
only(ws)
2425
else
25-
map(Tracker.data, back(ws))
26+
ws
2627
end
28+
return map(Tracker.data, back(_ws))
2729
end
28-
return pullback
30+
return value, tracker_pullback
2931
end
3032

3133
function AD.derivative(::AD.TrackerBackend, f, xs::Number...)

src/AbstractDifferentiation.jl

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -221,48 +221,28 @@ end
221221
end
222222

223223
function pullback_function(ab::AbstractBackend, f, xs...)
224-
return (ws) -> begin
225-
return gradient(lowest(ab), (xs...,) -> begin
226-
vs = f(xs...)
227-
if ws isa Tuple
228-
@assert length(vs) == length(ws)
229-
return sum(Base.splat(_dot), zip(ws, vs))
230-
else
231-
return _dot(vs, ws)
232-
end
233-
end, xs...)
234-
end
224+
_, pbf = value_and_pullback_function(ab, f, xs...)
225+
return pbf
235226
end
236227
function value_and_pullback_function(
237228
ab::AbstractBackend,
238229
f,
239230
xs...,
240231
)
241-
return (ws) -> begin
242-
local value
243-
primalcalled = false
244-
if ab isa AbstractFiniteDifference
245-
value = primal_value(ab, nothing, f, xs)
246-
primalcalled = true
247-
end
248-
if ws === nothing
249-
vs = f(xs...)
250-
if !primalcalled
251-
value = primal_value(lowest(ab), vs, f, xs)
252-
primalcalled = true
253-
end
254-
return value, nothing
255-
end
256-
pb = pullback_function(lowest(ab), (_xs...,) -> begin
232+
value = f(xs...)
233+
function pullback_function(ws)
234+
function pullback_gradient_function(_xs...)
257235
vs = f(_xs...)
258-
if !primalcalled
259-
value = primal_value(lowest(ab), vs, f, xs)
260-
primalcalled = true
236+
if ws isa Tuple
237+
@assert length(vs) == length(ws)
238+
return sum(Base.splat(_dot), zip(ws, vs))
239+
else
240+
return _dot(vs, ws)
261241
end
262-
return vs
263-
end, xs...)(ws)
264-
return value, pb
242+
end
243+
return gradient(lowest(ab), pullback_gradient_function, xs...)
265244
end
245+
return value, pullback_function
266246
end
267247

268248
struct LazyDerivative{B, F, X}
@@ -494,6 +474,12 @@ macro primitive(expr)
494474
name = fdef[:name]
495475
if name == :pushforward_function
496476
return define_pushforward_function_and_friends(fdef) |> esc
477+
elseif name == :value_and_pullback_function
478+
return define_value_and_pullback_function_and_friends(fdef) |> esc
479+
elseif name == :jacobian
480+
return define_jacobian_and_friends(fdef) |> esc
481+
elseif name == :primal_value
482+
return define_primal_value(fdef) |> esc
497483
elseif name == :pullback_function
498484
return define_pullback_function_and_friends(fdef) |> esc
499485
else
@@ -537,30 +523,29 @@ function define_pushforward_function_and_friends(fdef)
537523
return funcs
538524
end
539525

540-
function define_pullback_function_and_friends(fdef)
541-
fdef[:name] = :($(AbstractDifferentiation).pullback_function)
526+
function define_value_and_pullback_function_and_friends(fdef)
527+
fdef[:name] = :($(AbstractDifferentiation).value_and_pullback_function)
542528
args = fdef[:args]
543529
funcs = quote
544530
$(ExprTools.combinedef(fdef))
545531
function $(AbstractDifferentiation).jacobian($(args...),)
546-
value_and_pbf = $(value_and_pullback_function)($(args...),)
547-
value, _ = value_and_pbf(nothing)
532+
value, pbf = $(value_and_pullback_function)($(args...),)
548533
identity_like = $(identity_matrix_like)(value)
549534
if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}}
550535
return map(identity_like) do identity_like_i
551536
return mapreduce(vcat, $(_eachcol).(identity_like_i)...) do (cols...)
552-
value_and_pbf(cols)[2]'
537+
pbf(cols)'
553538
end
554539
end
555540
elseif eltype(identity_like) <: AbstractMatrix
556541
# needed for Hessian computation:
557542
# value is a (grad,). Then, identity_like is a (matrix,).
558543
# cols loops over columns of the matrix
559544
return vcat.(mapslices(identity_like[1], dims=1) do cols
560-
adjoint.(value_and_pbf((cols,))[2])
545+
adjoint.(pbf((cols,)))
561546
end ...)
562547
else
563-
return adjoint.(value_and_pbf(identity_like)[2])
548+
return adjoint.(pbf(identity_like))
564549
end
565550
end
566551
end

test/defaults.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@ struct FDMBackend3{A} <: AD.AbstractFiniteDifference
3333
end
3434
FDMBackend3() = FDMBackend3(central_fdm(5, 1))
3535
const fdm_backend3 = FDMBackend3()
36-
AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
37-
return function (vs)
36+
AD.@primitive function value_and_pullback_function(ab::FDMBackend3, f, xs...)
37+
value = f(xs...)
38+
function fd3_pullback(vs)
3839
# Supports only single output
39-
if vs isa AbstractVector
40-
return FDM.j′vp(ab.alg, f, vs, xs...)
41-
else
42-
return FDM.j′vp(ab.alg, f, only(vs), xs...)
43-
end
40+
_vs = vs isa AbstractVector ? vs : only(vs)
41+
return FDM.j′vp(ab.alg, f, _vs, xs...)
4442
end
43+
return value, fd3_pullback
4544
end
4645
##
4746

@@ -90,16 +89,14 @@ AD.primal_value(::ForwardDiffBackend2, ::Any, f, xs) = ForwardDiff.value.(f(xs..
9089
## Zygote
9190
struct ZygoteBackend1 <: AD.AbstractReverseMode end
9291
const zygote_backend1 = ZygoteBackend1()
93-
AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
94-
return function (vs)
95-
# Supports only single output
96-
_, back = Zygote.pullback(f, xs...)
97-
if vs isa AbstractVector
98-
back(vs)
99-
else
100-
back(only(vs))
101-
end
92+
AD.@primitive function value_and_pullback_function(ab::ZygoteBackend1, f, xs...)
93+
# Supports only single output
94+
value, back = Zygote.pullback(f, xs...)
95+
function zygote_pullback(vs)
96+
_vs = vs isa AbstractVector ? vs : only(vs)
97+
return back(_vs)
10298
end
99+
return value, zygote_pullback
103100
end
104101

105102
@testset "defaults" begin

test/ruleconfig.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using AbstractDifferentiation
2+
using ChainRulesCore
23
using Test
34
using Zygote
45

@@ -52,4 +53,16 @@ using Zygote
5253
end
5354
@test AD.jacobian(ad, f, [1, 2, 3], 3) == ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
5455
end
56+
57+
# issue #57
58+
@testset "primal computation in rrule" begin
59+
function myfunc(x)
60+
@info "This should not be logged if I have an rrule"
61+
x
62+
end
63+
ChainRulesCore.rrule(::typeof(myfunc), x) = (x, (y -> (NoTangent(), y)))
64+
65+
@test_logs Zygote.gradient(myfunc, 1) # nothing is logged
66+
@test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
67+
end
5568
end

test/test_utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
247247
w = rand(rng, length(fjac(xvec, yvec)))
248248
if multiple_inputs
249249
pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w)
250-
valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w)
250+
valvec, pbf2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)
251+
pb2 = pbf2(w)
251252

252253
if test_types
253254
@test valvec isa Vector{Float64}
@@ -263,8 +264,10 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
263264
@test yvec == yvec2
264265
end
265266

266-
valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w)
267-
valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w)
267+
valvec1, pbf1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)
268+
pb1 = pbf1(w)
269+
valvec2, pbf2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)
270+
pb2 = pbf2(w)
268271
if test_types
269272
@test valvec1 isa Vector{Float64}
270273
@test valvec2 isa Vector{Float64}

0 commit comments

Comments
 (0)