-
Notifications
You must be signed in to change notification settings - Fork 24
Update gradient interface, support AbstractDifferentiation #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
ea589ab
7a32038
4a2bbb8
9c15b0e
585e244
7cf1c12
fc700bb
05f8334
c56cc08
ea5be88
6b94ce8
59686f4
aff22b0
7cdcc2e
b402024
e79ec2b
3f93199
4bf6f75
d407660
d046528
2e062c3
9ac83c0
2d8d611
fd98409
68a80a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,16 @@ | ||
| name = "ProximalAlgorithms" | ||
| uuid = "140ffc9f-1907-541a-a177-7475e0a401e9" | ||
| version = "0.5.5" | ||
| version = "0.6.0" | ||
|
|
||
| [deps] | ||
| AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" | ||
| LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
| Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
| ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" | ||
| Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
|
||
| [compat] | ||
| AbstractDifferentiation = "0.6" | ||
| LinearAlgebra = "1.2" | ||
| Printf = "1.2" | ||
| ProximalCore = "0.1" | ||
| Zygote = "0.6" | ||
| julia = "1.2" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,7 @@ | |
| # The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite). | ||
| # | ||
| # To evaluate these first-order primitives, in ProximalAlgorithms: | ||
| # * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [Zygote](https://github.com/FluxML/Zygote.jl)). | ||
| # * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends). | ||
| # * ``\operatorname{prox}_{f_i}`` relies on the intereface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15). | ||
| # Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms). | ||
| # | ||
|
|
@@ -51,11 +51,14 @@ | |
| # which we will solve using the fast proximal gradient method (also known as fast forward-backward splitting): | ||
|
|
||
| using LinearAlgebra | ||
| using Zygote | ||
| using AbstractDifferentiation: ZygoteBackend | ||
| using ProximalOperators | ||
| using ProximalAlgorithms | ||
|
|
||
| quadratic_cost = ProximalAlgorithms.ZygoteFunction( | ||
| x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x) | ||
| quadratic_cost = ProximalAlgorithms.AutoDifferentiable( | ||
| x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x), | ||
| ZygoteBackend() | ||
| ) | ||
| box_indicator = ProximalOperators.IndBox(0, 1) | ||
|
|
||
|
|
@@ -70,7 +73,8 @@ solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator) | |
|
|
||
| # We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards: | ||
|
|
||
| -ProximalAlgorithms.gradient(quadratic_cost, solution)[1] | ||
| v, pb = ProximalAlgorithms.value_and_pullback(quadratic_cost, solution) | ||
| -pb() | ||
|
||
|
|
||
| # Or by plotting the solution against the cost function and constraint: | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| julia: | ||
| julia --project=. | ||
|
|
||
| instantiate: | ||
| julia --project=. -e 'using Pkg; Pkg.instantiate()' | ||
|
|
||
| test: | ||
| julia --project=. -e 'using Pkg; Pkg.test()' | ||
|
|
||
| format: | ||
| julia --project=. -e 'using JuliaFormatter: format; format(".")' | ||
|
|
||
| docs: | ||
| julia --project=./docs docs/make.jl |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,46 @@ | ||
| module ProximalAlgorithms | ||
|
|
||
| using AbstractDifferentiation | ||
| using ProximalCore | ||
| using ProximalCore: prox, prox!, gradient, gradient! | ||
| using ProximalCore: prox, prox! | ||
|
|
||
| const RealOrComplex{R} = Union{R,Complex{R}} | ||
| const Maybe{T} = Union{T,Nothing} | ||
|
|
||
| """ | ||
| AutoDifferentiable(f, backend) | ||
| Wrap function `f` to be auto-differentiated using `backend`. | ||
|
||
| The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl). | ||
| """ | ||
| struct AutoDifferentiable{F, B} | ||
| f::F | ||
| backend::B | ||
| end | ||
|
|
||
| (f::AutoDifferentiable)(x) = f.f(x) | ||
|
|
||
| """ | ||
| value_and_pullback(f, x) | ||
| Return a tuple containing the value of `f` at `x`, and the pullback function `pb`. | ||
| Function `pb`, once called, yields the gradient of `f` at `x`. | ||
| """ | ||
| value_and_pullback | ||
|
|
||
| function value_and_pullback(f::AutoDifferentiable, x) | ||
| fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x) | ||
| return fx, () -> pb(one(fx))[1] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Striclty speaking this is not a pullback if it takes no input. The point of a pullback is to take a cotangent and pull it back into the input space.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah right, I was unsure here, thanks for pointing that out. I guess I could call it simply “closure”
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed |
||
| end | ||
|
|
||
| function value_and_pullback(f::ProximalCore.Zero, x) | ||
| f(x), () -> zero(x) | ||
| end | ||
|
|
||
| # various utilities | ||
|
|
||
| include("utilities/ad.jl") | ||
| include("utilities/fb_tools.jl") | ||
| include("utilities/iteration_tools.jl") | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.