diff --git a/Project.toml b/Project.toml index 085f86c..20150b9 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,8 @@ version = "0.5.4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] ProximalCore = "0.1" -Zygote = "0.6" julia = "1.2" diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index 743c0ea..03210cc 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -1,5 +1,7 @@ module ProximalAlgorithms +using Requires + using ProximalCore using ProximalCore: prox, prox!, gradient, gradient! @@ -8,7 +10,6 @@ const Maybe{T} = Union{T,Nothing} # various utilities -include("utilities/ad.jl") include("utilities/fb_tools.jl") include("utilities/iteration_tools.jl") @@ -98,4 +99,18 @@ include("algorithms/li_lin.jl") include("algorithms/sfista.jl") include("algorithms/panocplus.jl") +# autodiff backends + +include("autodiff/backends.jl") +function __init__() + @require Yota = "cd998857-8626-517d-b929-70ad188a48f0" begin + println("Enabling Yota AD backend (YotaFunction)") + include("autodiff/yota.jl") + end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + println("Enabling Zygote AD backend (ZygoteFunction)") + include("autodiff/zygote.jl") + end +end + end # module diff --git a/src/autodiff/backends.jl b/src/autodiff/backends.jl new file mode 100644 index 0000000..b9d4320 --- /dev/null +++ b/src/autodiff/backends.jl @@ -0,0 +1,11 @@ +struct YotaFunction{F} + f::F +end + +(f::YotaFunction)(x) = f.f(x) + +struct ZygoteFunction{F} + f::F +end + +(f::ZygoteFunction)(x) = f.f(x) diff --git a/src/autodiff/yota.jl b/src/autodiff/yota.jl new file mode 100644 index 0000000..7b2fea5 --- /dev/null +++ b/src/autodiff/yota.jl @@ -0,0 +1,8 @@ +using ProximalCore +using .Yota: grad + +function ProximalCore.gradient!(grad_x, f::YotaFunction, x) + f_x, g = grad(f.f, x) + grad_x .= g[2] + return f_x +end diff --git a/src/autodiff/zygote.jl b/src/autodiff/zygote.jl new file mode 100644 index 0000000..c4c3262 --- /dev/null +++ b/src/autodiff/zygote.jl @@ -0,0 +1,8 @@ +using ProximalCore +using .Zygote: pullback + +function ProximalCore.gradient!(grad_x, f::ZygoteFunction, x) + f_x, pb = pullback(f.f, x) + grad_x .= pb(one(f_x))[1] + return f_x +end diff --git a/src/utilities/ad.jl b/src/utilities/ad.jl deleted file mode 100644 index f14c097..0000000 --- a/src/utilities/ad.jl +++ /dev/null @@ -1,8 +0,0 @@ -using Zygote: pullback -using ProximalCore - -function ProximalCore.gradient!(grad, f, x) - fx, pb = pullback(f, x) - grad .= pb(one(fx))[1] - return fx -end diff --git a/test/Project.toml b/test/Project.toml index 27ac831..acd48f1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,3 +7,5 @@ ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Yota = "cd998857-8626-517d-b929-70ad188a48f0" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/utilities/test_ad.jl b/test/autodiff/test_ad.jl similarity index 80% rename from test/utilities/test_ad.jl rename to test/autodiff/test_ad.jl index 8abb327..7514418 100644 --- a/test/utilities/test_ad.jl +++ b/test/autodiff/test_ad.jl @@ -2,8 +2,12 @@ using Test using LinearAlgebra using ProximalOperators: NormL1 using ProximalAlgorithms +using Yota, Zygote -@testset "Autodiff ($T)" for T in [Float32, Float64, ComplexF32, ComplexF64] +@testset "Autodiff ($T, $AD)" for (T, AD) in Iterators.product( + [Float32, Float64, ComplexF32, ComplexF64], + [ProximalAlgorithms.ZygoteFunction, ProximalAlgorithms.YotaFunction] +) R = real(T) A = T[ 1.0 -2.0 3.0 -4.0 5.0 @@ -12,7 +16,7 @@ using ProximalAlgorithms -1.0 -1.0 -1.0 1.0 3.0 ] b = T[1.0, 2.0, 3.0, 4.0] - f(x) = R(1/2) * norm(A * x - b, 2)^2 + f = AD(x -> R(1/2) * norm(A * x .- b, 2)^2) Lf = opnorm(A)^2 m, n = size(A) diff --git a/test/definitions/compose.jl b/test/definitions/compose.jl index 8e38c2a..fa1cc3f 100644 --- a/test/definitions/compose.jl +++ b/test/definitions/compose.jl @@ -14,7 +14,7 @@ end function compose_affine_gradient!(y, g::ComposeAffine, x) res = g.A * x .+ g.b - gradres, v = gradient(g.f, res) + gradres, v = ProximalCore.gradient(g.f, res) mul!(y, adjoint(g.A), gradres) return v end diff --git a/test/runtests.jl b/test/runtests.jl index 82460bc..a8160fb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,8 @@ using Test include("definitions/arraypartition.jl") include("definitions/compose.jl") -include("utilities/test_ad.jl") +include("autodiff/test_ad.jl") + include("utilities/test_iteration_tools.jl") include("utilities/test_fb_tools.jl")