Merged
Changes from 19 commits
6 changes: 3 additions & 3 deletions Project.toml
@@ -1,16 +1,16 @@
name = "ProximalAlgorithms"
uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
version = "0.5.5"
version = "0.6.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractDifferentiation = "0.6"
LinearAlgebra = "1.2"
Printf = "1.2"
ProximalCore = "0.1"
Zygote = "0.6"
julia = "1.2"
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
@@ -7,6 +8,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Documenter = "1"
7 changes: 5 additions & 2 deletions docs/src/examples/sparse_linear_regression.jl
@@ -51,10 +51,13 @@ end

mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2

using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

training_loss = ProximalAlgorithms.ZygoteFunction(
wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input))
training_loss = ProximalAlgorithms.AutoDifferentiable(
wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
ZygoteBackend()
)

# As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
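The remainder of this example is collapsed in the diff view. As a rough sketch of how it could continue (the regularization weight, `n_weights`, and the choice of PANOC as solver are assumptions for illustration, not the file's actual code):

```julia
using ProximalOperators
using ProximalAlgorithms

reg = ProximalOperators.NormL1(1.0)   # L1 regularizer; the weight 1.0 is a placeholder

n_weights = 14                        # hypothetical: number of features plus a bias term
panoc = ProximalAlgorithms.PANOC()
solution, iterations = panoc(x0=zeros(n_weights), f=training_loss, g=reg)
```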
44 changes: 27 additions & 17 deletions docs/src/guide/custom_objectives.jl
@@ -12,29 +12,32 @@
#
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref),
# relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation.
# Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken,
# and everything will work out of the box.
# To compute gradients, algorithms use [`ProximalAlgorithms.value_and_pullback`](@ref):
# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation
# with any of its supported backends, whenever functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref),
# as the examples below show.
#
# If however one would like to provide their own gradient implementation (e.g. for efficiency reasons),
# they can simply implement a method for [`ProximalCore.gradient!`](@ref).
# If, however, you would like to provide your own gradient implementation (e.g. for efficiency reasons),
# you can simply implement a method for [`ProximalAlgorithms.value_and_pullback`](@ref) on your own function type.
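For instance, a minimal sketch of such a custom method, for a hypothetical quadratic function type (`MyQuadratic` is not part of the package), could look like this, following the convention used here that the returned closure takes no arguments and yields the gradient:

```julia
using LinearAlgebra
using ProximalAlgorithms

# Hypothetical smooth function type: f(x) = x'Qx / 2 + q'x, with Q symmetric.
struct MyQuadratic{M,V}
    Q::M
    q::V
end

(f::MyQuadratic)(x) = dot(f.Q * x, x) / 2 + dot(f.q, x)

# Hand-written gradient, bypassing the AD fallback: return the value of f at x
# together with a closure that yields the gradient when called.
function ProximalAlgorithms.value_and_pullback(f::MyQuadratic, x)
    fx = f(x)
    return fx, () -> f.Q * x .+ f.q
end
```

With such a method in place, a `MyQuadratic` instance can be passed as the smooth term to any of the solvers without involving an AD backend.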
#
# ```@docs
# ProximalCore.prox
# ProximalCore.prox!
# ProximalCore.gradient
# ProximalCore.gradient!
# ProximalAlgorithms.value_and_pullback
# ProximalAlgorithms.AutoDifferentiable
# ```
#
# ## Example: constrained Rosenbrock
#
# Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is

using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

rosenbrock2D = ProximalAlgorithms.ZygoteFunction(
x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2
rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
ZygoteBackend()
)

# To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
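The corresponding hunk is collapsed in this view. As a sketch of what such a definition can look like (the `IndUnitBall` type below is hypothetical, not necessarily the file's exact code):

```julia
using LinearAlgebra
using ProximalCore

# Hypothetical indicator of the closed unit (Euclidean) ball.
struct IndUnitBall end

(::IndUnitBall)(x) = norm(x) <= 1 ? zero(eltype(x)) : eltype(x)(Inf)

# The prox of an indicator is the projection onto the set, for any gamma > 0.
function ProximalCore.prox!(y, ::IndUnitBall, x, gamma)
    if norm(x) <= 1
        y .= x
    else
        y .= x ./ norm(x)
    end
    return zero(eltype(x))   # value of the indicator at the projected point
end
```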
@@ -82,17 +85,23 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co

mutable struct Counting{T}
f::T
eval_count::Int
gradient_count::Int
prox_count::Int
end

Counting(f::T) where T = Counting{T}(f, 0, 0)
Counting(f::T) where T = Counting{T}(f, 0, 0, 0)

# Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there:
# Now we only need to intercept any call to `value_and_pullback` and `prox!` and increase counters there:

function ProximalCore.gradient!(y, f::Counting, x)
f.gradient_count += 1
return ProximalCore.gradient!(y, f.f, x)
function ProximalAlgorithms.value_and_pullback(f::Counting, x)
f.eval_count += 1
fx, pb = ProximalAlgorithms.value_and_pullback(f.f, x)
function counting_pullback()
f.gradient_count += 1
return pb()
end
return fx, counting_pullback
end

function ProximalCore.prox!(y, f::Counting, x, gamma)
@@ -109,5 +118,6 @@ solution, iterations = panoc(x0=-ones(2), f=f, g=g)

# and check how many operations were actually performed:

println(f.gradient_count)
println(g.prox_count)
println("function evals: $(f.eval_count)")
println("gradient evals: $(f.gradient_count)")
println(" prox evals: $(g.prox_count)")
12 changes: 8 additions & 4 deletions docs/src/guide/getting_started.jl
@@ -20,7 +20,7 @@
# The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
#
# To evaluate these first-order primitives, in ProximalAlgorithms:
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [Zygote](https://github.com/FluxML/Zygote.jl)).
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
# * ``\operatorname{prox}_{f_i}`` relies on the interface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
# Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
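As a rough sketch of how these two primitives are evaluated in practice (the function, point, and stepsize below are purely illustrative):

```julia
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalCore: prox
using ProximalOperators
using ProximalAlgorithms

# Smooth term: value and gradient through the AD backend.
f = ProximalAlgorithms.AutoDifferentiable(x -> sum(abs2, x) / 2, ZygoteBackend())
fx, pb = ProximalAlgorithms.value_and_pullback(f, [1.0, -2.0])
grad_f = pb()                       # gradient of f at [1.0, -2.0]

# Nonsmooth term: proximal mapping through the ProximalOperators interface.
g = ProximalOperators.NormL1(0.1)
z, g_z = prox(g, [1.0, -2.0], 0.5)  # proximal point of g with stepsize 0.5, and g(z)
```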
#
@@ -51,11 +51,14 @@
# which we will solve using the fast proximal gradient method (also known as fast forward-backward splitting):

using LinearAlgebra
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalOperators
using ProximalAlgorithms

quadratic_cost = ProximalAlgorithms.ZygoteFunction(
x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x)
quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
ZygoteBackend()
)
box_indicator = ProximalOperators.IndBox(0, 1)

@@ -70,7 +73,8 @@ solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator)

# We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:

-ProximalAlgorithms.gradient(quadratic_cost, solution)[1]
v, pb = ProximalAlgorithms.value_and_pullback(quadratic_cost, solution)
-pb()
Contributor: Maybe explain that the pullback outputs the gradient?

Member (author): Addressed.


# Or by plotting the solution against the cost function and constraint:

14 changes: 14 additions & 0 deletions justfile
@@ -0,0 +1,14 @@
julia:
julia --project=.

instantiate:
julia --project=. -e 'using Pkg; Pkg.instantiate()'

test:
julia --project=. -e 'using Pkg; Pkg.test()'

format:
julia --project=. -e 'using JuliaFormatter: format; format(".")'

docs:
julia --project=./docs docs/make.jl
36 changes: 34 additions & 2 deletions src/ProximalAlgorithms.jl
@@ -1,14 +1,46 @@
module ProximalAlgorithms

using AbstractDifferentiation
using ProximalCore
using ProximalCore: prox, prox!, gradient, gradient!
using ProximalCore: prox, prox!

const RealOrComplex{R} = Union{R,Complex{R}}
const Maybe{T} = Union{T,Nothing}

"""
    AutoDifferentiable(f, backend)

Wrap function `f` to be auto-differentiated using `backend`.
Contributor: Specify that it's a callable struct?

Member (author): Addressed.

The backend can be any of those supported by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
"""
struct AutoDifferentiable{F, B}
f::F
backend::B
end

(f::AutoDifferentiable)(x) = f.f(x)

"""
    value_and_pullback(f, x)

Return a tuple containing the value of `f` at `x`, and the pullback function `pb`.
Function `pb`, once called, yields the gradient of `f` at `x`.
"""
value_and_pullback

function value_and_pullback(f::AutoDifferentiable, x)
fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
return fx, () -> pb(one(fx))[1]
Contributor: Strictly speaking, this is not a pullback if it takes no input. The point of a pullback is to take a cotangent and pull it back into the input space. In this case the cotangent is a scalar, and taking this scalar to be 1 returns the gradient, so it's okay, but the terminology might confuse autodiff people.

Member (author): Yeah, right, I was unsure here, thanks for pointing that out. I guess I could call it simply a "closure": value_and_gradient_closure or something like that.

Member (author): Addressed.

end

function value_and_pullback(f::ProximalCore.Zero, x)
f(x), () -> zero(x)
end
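A brief usage sketch of the two methods above (the wrapped function and evaluation points are illustrative):

```julia
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalCore
using ProximalAlgorithms

f = ProximalAlgorithms.AutoDifferentiable(x -> sum(x .^ 2), ZygoteBackend())
fx, pb = ProximalAlgorithms.value_and_pullback(f, [1.0, 2.0, 3.0])
# fx == 14.0, and pb() yields the gradient [2.0, 4.0, 6.0]

g = ProximalCore.Zero()
gx, pb0 = ProximalAlgorithms.value_and_pullback(g, [1.0, 2.0, 3.0])
# gx == 0, and pb0() yields a zero vector of the same shape as the input
```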

# various utilities

include("utilities/ad.jl")
include("utilities/fb_tools.jl")
include("utilities/iteration_tools.jl")

6 changes: 4 additions & 2 deletions src/algorithms/davis_yin.jl
@@ -55,7 +55,8 @@ end
function Base.iterate(iter::DavisYinIteration)
z = copy(iter.x0)
xg, = prox(iter.g, z, iter.gamma)
grad_f_xg, = gradient(iter.f, xg)
f_xg, pb = value_and_pullback(iter.f, xg)
grad_f_xg = pb()
z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
xh, = prox(iter.h, z_half, iter.gamma)
res = xh - xg
@@ -66,7 +67,8 @@

function Base.iterate(iter::DavisYinIteration, state::DavisYinState)
prox!(state.xg, iter.g, state.z, iter.gamma)
gradient!(state.grad_f_xg, iter.f, state.xg)
f_xg, pb = value_and_pullback(iter.f, state.xg)
state.grad_f_xg .= pb()
state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg
prox!(state.xh, iter.h, state.z_half, iter.gamma)
state.res .= state.xh .- state.xg
6 changes: 4 additions & 2 deletions src/algorithms/fast_forward_backward.jl
@@ -68,7 +68,8 @@

function Base.iterate(iter::FastForwardBackwardIteration)
x = copy(iter.x0)
grad_f_x, f_x = gradient(iter.f, x)
f_x, pb = value_and_pullback(iter.f, x)
grad_f_x = pb()
gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
y = x - gamma .* grad_f_x
z, g_z = prox(iter.g, y, gamma)
@@ -103,7 +104,8 @@ function Base.iterate(iter::FastForwardBackwardIteration{R}, state::FastForwardB
state.x .= state.z .+ beta .* (state.z .- state.z_prev)
state.z_prev, state.z = state.z, state.z_prev

state.f_x = gradient!(state.grad_f_x, iter.f, state.x)
state.f_x, pb = value_and_pullback(iter.f, state.x)
state.grad_f_x .= pb()
state.y .= state.x .- state.gamma .* state.grad_f_x
state.g_z = prox!(state.z, iter.g, state.y, state.gamma)
state.res .= state.x .- state.z
6 changes: 4 additions & 2 deletions src/algorithms/forward_backward.jl
@@ -59,7 +59,8 @@

function Base.iterate(iter::ForwardBackwardIteration)
x = copy(iter.x0)
grad_f_x, f_x = gradient(iter.f, x)
f_x, pb = value_and_pullback(iter.f, x)
grad_f_x = pb()
gamma = iter.gamma === nothing ? 1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
y = x - gamma .* grad_f_x
z, g_z = prox(iter.g, y, gamma)
@@ -81,7 +82,8 @@ function Base.iterate(iter::ForwardBackwardIteration{R}, state::ForwardBackwardS
state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x
else
state.x, state.z = state.z, state.x
state.f_x = gradient!(state.grad_f_x, iter.f, state.x)
state.f_x, pb = value_and_pullback(iter.f, state.x)
state.grad_f_x .= pb()
end

state.y .= state.x .- state.gamma .* state.grad_f_x
9 changes: 6 additions & 3 deletions src/algorithms/li_lin.jl
@@ -62,7 +62,8 @@

function Base.iterate(iter::LiLinIteration{R}) where {R}
y = copy(iter.x0)
grad_f_y, f_y = gradient(iter.f, y)
f_y, pb = value_and_pullback(iter.f, y)
grad_f_y = pb()

# TODO: initialize gamma if not provided
# TODO: authors suggest Barzilai-Borwein rule?
@@ -102,7 +103,8 @@ function Base.iterate(
else
# TODO: re-use available space in state?
# TODO: backtrack gamma at x
grad_f_x, f_x = gradient(iter.f, x)
f_x, pb = value_and_pullback(iter.f, x)
grad_f_x = pb()
x_forward = state.x - state.gamma .* grad_f_x
v, g_v = prox(iter.g, x_forward, state.gamma)
Fv = iter.f(v) + g_v
@@ -121,7 +123,8 @@ function Base.iterate(
Fx = Fv
end

state.f_y = gradient!(state.grad_f_y, iter.f, state.y)
state.f_y, pb = value_and_pullback(iter.f, state.y)
state.grad_f_y .= pb()
state.y_forward .= state.y .- state.gamma .* state.grad_f_y
state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma)

12 changes: 8 additions & 4 deletions src/algorithms/panoc.jl
@@ -86,7 +86,8 @@ f_model(iter::PANOCIteration, state::PANOCState) = f_model(state.f_Ax, state.At_
function Base.iterate(iter::PANOCIteration{R}) where R
x = copy(iter.x0)
Ax = iter.A * x
grad_f_Ax, f_Ax = gradient(iter.f, Ax)
f_Ax, pb = value_and_pullback(iter.f, Ax)
grad_f_Ax = pb()
gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : iter.gamma
At_grad_f_Ax = iter.A' * grad_f_Ax
y = x - gamma .* At_grad_f_Ax
@@ -152,7 +153,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where

state.x_d .= state.x .+ state.d
state.Ax_d .= state.Ax .+ state.Ad
state.f_Ax_d = gradient!(state.grad_f_Ax_d, iter.f, state.Ax_d)
state.f_Ax_d, pb = value_and_pullback(iter.f, state.Ax_d)
state.grad_f_Ax_d .= pb()
mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d)

copyto!(state.x, state.x_d)
@@ -189,7 +191,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where
# along a line using interpolation and linear combinations
# this allows saving operations
if isinf(f_Az)
f_Az = gradient!(state.grad_f_Az, iter.f, state.Az)
f_Az, pb = value_and_pullback(iter.f, state.Az)
state.grad_f_Az .= pb()
end
if isinf(c)
mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az)
@@ -203,7 +206,8 @@ function Base.iterate(iter::PANOCIteration{R, Tx, Tf}, state::PANOCState) where
else
# otherwise, in the general case where f is only smooth, we compute
# one gradient and matvec per backtracking step
state.f_Ax = gradient!(state.grad_f_Ax, iter.f, state.Ax)
state.f_Ax, pb = value_and_pullback(iter.f, state.Ax)
state.grad_f_Ax .= pb()
mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax)
end
