Skip to content

Commit 9e7e07b

Browse files
authored
Simpler defaults without FiniteDifferences special cases (#96)
* Simpler defaults without FiniteDifferences special cases * Remove support for macros that were reintroduced unintentionally
1 parent 3c18e86 commit 9e7e07b

File tree

1 file changed

+24
-70
lines changed

1 file changed

+24
-70
lines changed

src/AbstractDifferentiation.jl

Lines changed: 24 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,8 @@ function value_and_gradient(ab::AbstractBackend, f, xs...)
6666
return value, reshape.(adjoint.(jacs),size.(xs))
6767
end
6868
function value_and_jacobian(ab::AbstractBackend, f, xs...)
69-
local value
70-
primalcalled = false
71-
if lowest(ab) isa AbstractFiniteDifference
72-
value = primal_value(ab, nothing, f, xs)
73-
primalcalled = true
74-
end
75-
jacs = jacobian(lowest(ab), (_xs...,) -> begin
76-
v = f(_xs...)
77-
if !primalcalled
78-
value = primal_value(ab, v, f, xs)
79-
primalcalled = true
80-
end
81-
return v
82-
end, xs...)
83-
69+
value = f(xs...)
70+
jacs = jacobian(lowest(ab), f, xs...)
8471
return value, jacs
8572
end
8673
function value_and_hessian(ab::AbstractBackend, f, x)
@@ -89,71 +76,54 @@ function value_and_hessian(ab::AbstractBackend, f, x)
8976
x = only(x)
9077
end
9178

92-
local value
93-
primalcalled = false
94-
if ab isa AbstractFiniteDifference
95-
value = primal_value(ab, nothing, f, (x,))
96-
primalcalled = true
97-
end
79+
value = f(x)
9880
hess = jacobian(second_lowest(ab), _x -> begin
99-
v, g = value_and_gradient(lowest(ab), f, _x)
100-
if !primalcalled
101-
value = primal_value(ab, v, f, (x,))
102-
primalcalled = true
103-
end
81+
g = gradient(lowest(ab), f, _x)
10482
return g[1] # gradient returns a tuple
10583
end, x)
84+
10685
return value, hess
10786
end
10887
function value_and_hessian(ab::HigherOrderBackend, f, x)
10988
if x isa Tuple
11089
# only support computation of Hessian for functions with single input argument
11190
x = only(x)
11291
end
113-
local value
114-
primalcalled = false
92+
93+
value = f(x)
11594
hess = jacobian(second_lowest(ab), (_x,) -> begin
116-
v, g = value_and_gradient(lowest(ab), f, _x)
117-
if !primalcalled
118-
value = primal_value(ab, v, f, (x,))
119-
primalcalled = true
120-
end
95+
g = gradient(lowest(ab), f, _x)
12196
return g[1] # gradient returns a tuple
12297
end, x)
98+
12399
return value, hess
124100
end
125101
function value_gradient_and_hessian(ab::AbstractBackend, f, x)
126102
if x isa Tuple
127103
# only support computation of Hessian for functions with single input argument
128104
x = only(x)
129105
end
130-
local value
131-
primalcalled = false
106+
107+
value = f(x)
132108
grads, hess = value_and_jacobian(second_lowest(ab), _x -> begin
133-
v, g = value_and_gradient(lowest(ab), f, _x)
134-
if !primalcalled
135-
value = primal_value(second_lowest(ab), v, f, (x,))
136-
primalcalled = true
137-
end
109+
g = gradient(lowest(ab), f, _x)
138110
return g[1] # gradient returns a tuple
139111
end, x)
112+
140113
return value, (grads,), hess
141114
end
142115
function value_gradient_and_hessian(ab::HigherOrderBackend, f, x)
143116
if x isa Tuple
144117
# only support computation of Hessian for functions with single input argument
145118
x = only(x)
146119
end
147-
local value
148-
primalcalled = false
120+
121+
value = f(x)
149122
grads, hess = value_and_jacobian(second_lowest(ab), _x -> begin
150-
v, g = value_and_gradient(lowest(ab), f, _x)
151-
if !primalcalled
152-
value = primal_value(second_lowest(ab), v, f, (x,))
153-
primalcalled = true
154-
end
123+
g = gradient(lowest(ab), f, _x)
155124
return g[1] # gradient returns a tuple
156125
end, x)
126+
157127
return value, (grads,), hess
158128
end
159129

@@ -180,26 +150,16 @@ function value_and_pushforward_function(
180150
f,
181151
xs...,
182152
)
183-
return (ds) -> begin
153+
n = length(xs)
154+
value = f(xs...)
155+
pf_function = pushforward_function(lowest(ab), f, xs...)
156+
157+
return ds -> begin
184158
if !(ds isa Tuple)
185159
ds = (ds,)
186160
end
187-
@assert length(ds) == length(xs)
188-
local value
189-
primalcalled = false
190-
if ab isa AbstractFiniteDifference
191-
value = primal_value(ab, nothing, f, xs)
192-
primalcalled = true
193-
end
194-
pf = pushforward_function(lowest(ab), (_xs...,) -> begin
195-
vs = f(_xs...)
196-
if !primalcalled
197-
value = primal_value(lowest(ab), vs, f, xs)
198-
primalcalled = true
199-
end
200-
return vs
201-
end, xs...)(ds)
202-
161+
@assert length(ds) == n
162+
pf = pf_function(ds)
203163
return value, pf
204164
end
205165
end
@@ -476,12 +436,6 @@ macro primitive(expr)
476436
return define_pushforward_function_and_friends(fdef) |> esc
477437
elseif name == :value_and_pullback_function
478438
return define_value_and_pullback_function_and_friends(fdef) |> esc
479-
elseif name == :jacobian
480-
return define_jacobian_and_friends(fdef) |> esc
481-
elseif name == :primal_value
482-
return define_primal_value(fdef) |> esc
483-
elseif name == :pullback_function
484-
return define_pullback_function_and_friends(fdef) |> esc
485439
else
486440
throw("Unsupported AD primitive.")
487441
end

0 commit comments

Comments
 (0)