Skip to content

Commit b97ff8d

Browse files
committed
Update the extensions
1 parent d1dd9a0 commit b97ff8d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ext/ADNLPModelsZygoteExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ module ADNLPModelsZygoteExt
22

33
using Zygote, ADNLPModels
44

5-
function gradient(::ZygoteADGradient, f, x)
5+
function gradient(::ADNLPModels.ZygoteADGradient, f, x)
66
g = Zygote.gradient(f, x)[1]
77
return g === nothing ? zero(x) : g
88
end
9-
function gradient!(::ZygoteADGradient, g, f, x)
9+
function gradient!(::ADNLPModels.ZygoteADGradient, g, f, x)
1010
_g = Zygote.gradient(f, x)[1]
1111
g .= _g === nothing ? 0 : _g
1212
end
1313

14-
function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val)
14+
function Jprod!(::ADNLPModels.ZygoteADJprod, Jv, f, x, v, ::Val)
1515
Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1])
1616
return Jv
1717
end
1818

19-
function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val)
19+
function Jtprod!(::ADNLPModels.ZygoteADJtprod, Jtv, f, x, v, ::Val)
2020
g = Zygote.gradient(x -> dot(f(x), v), x)[1]
2121
if g === nothing
2222
Jtv .= zero(x)
@@ -26,14 +26,14 @@ function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val)
2626
return Jtv
2727
end
2828

29-
function jacobian(::ZygoteADJacobian, f, x)
29+
function jacobian(::ADNLPModels.ZygoteADJacobian, f, x)
3030
return Zygote.jacobian(f, x)[1]
3131
end
3232

33-
function hessian(b::ZygoteADHessian, f, x)
33+
function hessian(b::ADNLPModels.ZygoteADHessian, f, x)
3434
return jacobian(
35-
ForwardDiffADJacobian(length(x), f, x0 = x),
36-
x -> gradient(ZygoteADGradient(), f, x),
35+
ADNLPModels.ForwardDiffADJacobian(length(x), f, x0 = x),
36+
x -> gradient(ADNLPModels.ZygoteADGradient(), f, x),
3737
x,
3838
)
3939
end

0 commit comments

Comments
 (0)