Skip to content

Commit cb57b09

Browse files
Copilotyebai
andcommitted
Use value_and_gradient for efficiency and update dependency versions
Co-authored-by: yebai <3279477+yebai@users.noreply.github.com>
1 parent 7b6fc4d commit cb57b09

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

examples/1-mauna-loa/script.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,8 @@ function optimize_loss(loss, θ_init; optimizer=default_optimizer, maxiter=1_000
229229
backend = AutoMooncake()
230230
function fg!(F, G, x)
231231
if F !== nothing && G !== nothing
232-
val = loss_packed(x)
233-
grad = only(gradient(loss_packed, backend, x))
234-
G .= grad
232+
val, grad = value_and_gradient(loss_packed, backend, x)
233+
G .= only(grad)
235234
return val
236235
elseif G !== nothing
237236
grad = only(gradient(loss_packed, backend, x))

examples/3-parametric-heteroscedastic/script.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,8 @@ end;
5454
backend = AutoMooncake()
5555
function objective_and_gradient(F, G, flat_θ)
5656
if G !== nothing
57-
val = objective(flat_θ)
58-
grad = only(gradient(objective, backend, flat_θ))
59-
copyto!(G, grad)
57+
val, grad = value_and_gradient(objective, backend, flat_θ)
58+
copyto!(G, only(grad))
6059
if F !== nothing
6160
return val
6261
end

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1717

1818
[compat]
1919
Aqua = "0.8"
20-
DifferentiationInterface = "0.5, 0.6"
20+
DifferentiationInterface = "0.7"
2121
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
2222
Documenter = "1"
2323
FillArrays = "0.11, 0.12, 0.13, 1"
2424
FiniteDifferences = "0.9.6, 0.10, 0.11, 0.12"
2525
LinearAlgebra = "1"
26-
Mooncake = "0.3, 0.4, 0.5"
26+
Mooncake = "0.4"
2727
PDMats = "0.11"
2828
Pkg = "1"
2929
Plots = "1"

0 commit comments

Comments
 (0)