@@ -66,21 +66,8 @@ function value_and_gradient(ab::AbstractBackend, f, xs...)
66
66
return value, reshape .(adjoint .(jacs),size .(xs))
67
67
end
68
68
function value_and_jacobian (ab:: AbstractBackend , f, xs... )
69
- local value
70
- primalcalled = false
71
- if lowest (ab) isa AbstractFiniteDifference
72
- value = primal_value (ab, nothing , f, xs)
73
- primalcalled = true
74
- end
75
- jacs = jacobian (lowest (ab), (_xs... ,) -> begin
76
- v = f (_xs... )
77
- if ! primalcalled
78
- value = primal_value (ab, v, f, xs)
79
- primalcalled = true
80
- end
81
- return v
82
- end , xs... )
83
-
69
+ value = f (xs... )
70
+ jacs = jacobian (lowest (ab), f, xs... )
84
71
return value, jacs
85
72
end
86
73
function value_and_hessian (ab:: AbstractBackend , f, x)
@@ -89,71 +76,54 @@ function value_and_hessian(ab::AbstractBackend, f, x)
89
76
x = only (x)
90
77
end
91
78
92
- local value
93
- primalcalled = false
94
- if ab isa AbstractFiniteDifference
95
- value = primal_value (ab, nothing , f, (x,))
96
- primalcalled = true
97
- end
79
+ value = f (x)
98
80
hess = jacobian (second_lowest (ab), _x -> begin
99
- v, g = value_and_gradient (lowest (ab), f, _x)
100
- if ! primalcalled
101
- value = primal_value (ab, v, f, (x,))
102
- primalcalled = true
103
- end
81
+ g = gradient (lowest (ab), f, _x)
104
82
return g[1 ] # gradient returns a tuple
105
83
end , x)
84
+
106
85
return value, hess
107
86
end
108
87
function value_and_hessian (ab:: HigherOrderBackend , f, x)
109
88
if x isa Tuple
110
89
# only support computation of Hessian for functions with single input argument
111
90
x = only (x)
112
91
end
113
- local value
114
- primalcalled = false
92
+
93
+ value = f (x)
115
94
hess = jacobian (second_lowest (ab), (_x,) -> begin
116
- v, g = value_and_gradient (lowest (ab), f, _x)
117
- if ! primalcalled
118
- value = primal_value (ab, v, f, (x,))
119
- primalcalled = true
120
- end
95
+ g = gradient (lowest (ab), f, _x)
121
96
return g[1 ] # gradient returns a tuple
122
97
end , x)
98
+
123
99
return value, hess
124
100
end
125
101
function value_gradient_and_hessian (ab:: AbstractBackend , f, x)
126
102
if x isa Tuple
127
103
# only support computation of Hessian for functions with single input argument
128
104
x = only (x)
129
105
end
130
- local value
131
- primalcalled = false
106
+
107
+ value = f (x)
132
108
grads, hess = value_and_jacobian (second_lowest (ab), _x -> begin
133
- v, g = value_and_gradient (lowest (ab), f, _x)
134
- if ! primalcalled
135
- value = primal_value (second_lowest (ab), v, f, (x,))
136
- primalcalled = true
137
- end
109
+ g = gradient (lowest (ab), f, _x)
138
110
return g[1 ] # gradient returns a tuple
139
111
end , x)
112
+
140
113
return value, (grads,), hess
141
114
end
142
115
function value_gradient_and_hessian (ab:: HigherOrderBackend , f, x)
143
116
if x isa Tuple
144
117
# only support computation of Hessian for functions with single input argument
145
118
x = only (x)
146
119
end
147
- local value
148
- primalcalled = false
120
+
121
+ value = f (x)
149
122
grads, hess = value_and_jacobian (second_lowest (ab), _x -> begin
150
- v, g = value_and_gradient (lowest (ab), f, _x)
151
- if ! primalcalled
152
- value = primal_value (second_lowest (ab), v, f, (x,))
153
- primalcalled = true
154
- end
123
+ g = gradient (lowest (ab), f, _x)
155
124
return g[1 ] # gradient returns a tuple
156
125
end , x)
126
+
157
127
return value, (grads,), hess
158
128
end
159
129
@@ -180,26 +150,16 @@ function value_and_pushforward_function(
180
150
f,
181
151
xs... ,
182
152
)
183
- return (ds) -> begin
153
+ n = length (xs)
154
+ value = f (xs... )
155
+ pf_function = pushforward_function (lowest (ab), f, xs... )
156
+
157
+ return ds -> begin
184
158
if ! (ds isa Tuple)
185
159
ds = (ds,)
186
160
end
187
- @assert length (ds) == length (xs)
188
- local value
189
- primalcalled = false
190
- if ab isa AbstractFiniteDifference
191
- value = primal_value (ab, nothing , f, xs)
192
- primalcalled = true
193
- end
194
- pf = pushforward_function (lowest (ab), (_xs... ,) -> begin
195
- vs = f (_xs... )
196
- if ! primalcalled
197
- value = primal_value (lowest (ab), vs, f, xs)
198
- primalcalled = true
199
- end
200
- return vs
201
- end , xs... )(ds)
202
-
161
+ @assert length (ds) == n
162
+ pf = pf_function (ds)
203
163
return value, pf
204
164
end
205
165
end
@@ -476,12 +436,6 @@ macro primitive(expr)
476
436
return define_pushforward_function_and_friends (fdef) |> esc
477
437
elseif name == :value_and_pullback_function
478
438
return define_value_and_pullback_function_and_friends (fdef) |> esc
479
- elseif name == :jacobian
480
- return define_jacobian_and_friends (fdef) |> esc
481
- elseif name == :primal_value
482
- return define_primal_value (fdef) |> esc
483
- elseif name == :pullback_function
484
- return define_pullback_function_and_friends (fdef) |> esc
485
439
else
486
440
throw (" Unsupported AD primitive." )
487
441
end
0 commit comments