Skip to content

Commit da2c944

Browse files
authored
Allow generic arrays, not just vectors (#167)
1 parent 853a544 commit da2c944

File tree

12 files changed

+146
-90
lines changed

12 files changed

+146
-90
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/faq.md

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,11 @@ However, this can be switched to any other "inner" backend compatible with [Diff
1515

1616
## Input and output types
1717

18-
### Vectors
19-
20-
Functions that eat or spit out arbitrary vectors are supported, as long as the forward mapping _and_ conditions return vectors of the same size.
21-
22-
If you deal with small vectors (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
23-
2418
### Arrays
2519

26-
Functions that eat or spit out matrices and higher-order tensors are not supported.
27-
You can use `vec` and `reshape` for the conversion to and from vectors.
20+
Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size.
21+
22+
If you deal with small arrays (say, fewer than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
2823

2924
### Scalars
3025

examples/0_intro.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ This is essentially the componentwise square root function but with an additiona
2929
We can check that it does what it's supposed to do.
3030
=#
3131

32-
x = [4.0, 9.0]
32+
x = [1.0 2.0; 3.0 4.0]
3333
badsqrt(x)
3434
@test badsqrt(x) ≈ sqrt.(x) #src
3535

3636
#=
3737
Of course the Jacobian has an explicit formula.
3838
=#
3939

40-
J = Diagonal(0.5 ./ sqrt.(x))
40+
J = Diagonal(0.5 ./ vec(sqrt.(x)))
4141

4242
#=
4343
However, things start to go wrong when we compute it with autodiff, due to the [limitations of ForwardDiff.jl](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/) and [those of Zygote.jl](https://fluxml.ai/Zygote.jl/stable/limitations/).

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using ImplicitDifferentiation:
1414
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1515

1616
function ChainRulesCore.rrule(
17-
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector, args::Vararg{Any,N};
17+
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
1818
) where {N}
1919
y, z = implicit(x, args...)
2020

@@ -25,8 +25,10 @@ function ChainRulesCore.rrule(
2525

2626
function implicit_pullback((dy, dz))
2727
dy = unthunk(dy)
28-
dc = implicit.linear_solver(Aᵀ, -dy)
29-
dx = Bᵀ * dc
28+
dy_vec = vec(dy)
29+
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
30+
dx_vec = Bᵀ * dc_vec
31+
dx = reshape(dx_vec, size(x))
3032
df = NoTangent()
3133
dargs = ntuple(unimplemented_tangent, N)
3234
return (df, project_x(dx), dargs...)

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using ForwardDiff: Dual, Partials, partials, value
55
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B
66

77
function (implicit::ImplicitFunction)(
8-
x_and_dx::AbstractVector{Dual{T,R,N}}, args...
8+
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
99
) where {T,R,N}
1010
x = value.(x_and_dx)
1111
y, z = implicit(x, args...)
@@ -14,16 +14,17 @@ function (implicit::ImplicitFunction)(
1414
A = build_A(implicit, x, y, z, args...; suggested_backend)
1515
B = build_B(implicit, x, y, z, args...; suggested_backend)
1616

17-
dX = map(1:N) do k
17+
dX = ntuple(Val(N)) do k
1818
partials.(x_and_dx, k)
1919
end
20-
dC = mapreduce(hcat, dX) do dₖx
21-
B * dₖx
20+
dC_mat = mapreduce(hcat, dX) do dₖx
21+
dₖx_vec = vec(dₖx)
22+
dₖc_vec = B * dₖx_vec
2223
end
23-
dY = implicit.linear_solver(A, -dC)
24+
dY_mat = implicit.linear_solver(A, -dC_mat)
2425

25-
y_and_dy = map(eachindex(y)) do i
26-
Dual{T}(y[i], Partials(ntuple(k -> dY[i, k], Val(N))))
26+
y_and_dy = map(LinearIndices(y)) do i
27+
Dual{T}(y[i], Partials(ntuple(k -> dY_mat[i, k], Val(N))))
2728
end
2829

2930
return y_and_dy, z

src/ImplicitDifferentiation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using Krylov: gmres
2323
using LinearOperators: LinearOperator
2424
using LinearAlgebra: factorize
2525

26+
include("utils.jl")
2627
include("settings.jl")
2728
include("preparation.jl")
2829
include("implicit_function.jl")

src/execution.jl

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
11
const SYMMETRIC = false
22
const HERMITIAN = false
33

4-
struct JVP!{F,P,B,X,C}
4+
struct JVP!{F,P,B,I,C}
55
f::F
66
prep::P
77
backend::B
8-
x::X
8+
input::I
99
contexts::C
1010
end
1111

12-
struct VJP!{F,P,B,X,C}
12+
struct VJP!{F,P,B,I,C}
1313
f::F
1414
prep::P
1515
backend::B
16-
x::X
16+
input::I
1717
contexts::C
1818
end
1919

2020
function (po::JVP!)(res::AbstractVector, v::AbstractVector)
21-
(; f, backend, x, contexts, prep) = po
22-
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
21+
(; f, backend, input, contexts, prep) = po
22+
pushforward!(f, (res,), prep, backend, input, (v,), contexts...)
2323
return res
2424
end
2525

2626
function (po::VJP!)(res::AbstractVector, v::AbstractVector)
27-
(; f, backend, x, contexts, prep) = po
28-
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
27+
(; f, backend, input, contexts, prep) = po
28+
pullback!(f, (res,), prep, backend, input, (v,), contexts...)
2929
return res
3030
end
3131

3232
## A
3333

3434
function build_A(
3535
implicit::ImplicitFunction,
36-
x::AbstractVector,
37-
y::AbstractVector,
36+
x::AbstractArray,
37+
y::AbstractArray,
3838
z,
3939
args...;
4040
suggested_backend::AbstractADType,
@@ -58,21 +58,24 @@ function build_A_aux(
5858
(; conditions, backend, prep_A) = implicit
5959
actual_backend = isnothing(backend) ? suggested_backend : backend
6060
contexts = (Constant(x), Constant(z), map(Constant, args)...)
61+
f_vec = VecToVec(Switch12(conditions), y)
62+
y_vec = vec(y)
63+
dy_vec = vec(zero(y))
6164
prep_A_same = prepare_pushforward_same_point(
62-
Switch12(conditions), prep_A..., actual_backend, y, (zero(y),), contexts...
65+
f_vec, prep_A..., actual_backend, y_vec, (dy_vec,), contexts...
6366
)
64-
prod! = JVP!(Switch12(conditions), prep_A_same, actual_backend, y, contexts)
67+
prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, contexts)
6568
return LinearOperator(
66-
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y)
69+
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec)
6770
)
6871
end
6972

7073
## Aᵀ
7174

7275
function build_Aᵀ(
7376
implicit::ImplicitFunction,
74-
x::AbstractVector,
75-
y::AbstractVector,
77+
x::AbstractArray,
78+
y::AbstractArray,
7679
z,
7780
args...;
7881
suggested_backend::AbstractADType,
@@ -98,21 +101,24 @@ function build_Aᵀ_aux(
98101
(; conditions, backend, prep_Aᵀ) = implicit
99102
actual_backend = isnothing(backend) ? suggested_backend : backend
100103
contexts = (Constant(x), Constant(z), map(Constant, args)...)
104+
f_vec = VecToVec(Switch12(conditions), y)
105+
y_vec = vec(y)
106+
dc_vec = vec(zero(y))
101107
prep_Aᵀ_same = prepare_pullback_same_point(
102-
Switch12(conditions), prep_Aᵀ..., actual_backend, y, (zero(y),), contexts...
108+
f_vec, prep_Aᵀ..., actual_backend, y_vec, (dc_vec,), contexts...
103109
)
104-
prod! = VJP!(Switch12(conditions), prep_Aᵀ_same, actual_backend, y, contexts)
110+
prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, contexts)
105111
return LinearOperator(
106-
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y)
112+
eltype(y), length(y), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(y_vec)
107113
)
108114
end
109115

110116
## B
111117

112118
function build_B(
113119
implicit::ImplicitFunction,
114-
x::AbstractVector,
115-
y::AbstractVector,
120+
x::AbstractArray,
121+
y::AbstractArray,
116122
z,
117123
args...;
118124
suggested_backend::AbstractADType,
@@ -135,21 +141,24 @@ function build_B_aux(
135141
(; conditions, backend, prep_B) = implicit
136142
actual_backend = isnothing(backend) ? suggested_backend : backend
137143
contexts = (Constant(y), Constant(z), map(Constant, args)...)
144+
f_vec = VecToVec(conditions, x)
145+
x_vec = vec(x)
146+
dx_vec = vec(zero(x))
138147
prep_B_same = prepare_pushforward_same_point(
139-
conditions, prep_B..., actual_backend, x, (zero(x),), contexts...
148+
f_vec, prep_B..., actual_backend, x_vec, (dx_vec,), contexts...
140149
)
141-
prod! = JVP!(conditions, prep_B_same, actual_backend, x, contexts)
150+
prod! = JVP!(f_vec, prep_B_same, actual_backend, x_vec, contexts)
142151
return LinearOperator(
143-
eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x)
152+
eltype(y), length(y), length(x), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec)
144153
)
145154
end
146155

147156
## Bᵀ
148157

149158
function build_Bᵀ(
150159
implicit::ImplicitFunction,
151-
x::AbstractVector,
152-
y::AbstractVector,
160+
x::AbstractArray,
161+
y::AbstractArray,
153162
z,
154163
args...;
155164
suggested_backend::AbstractADType,
@@ -172,11 +181,14 @@ function build_Bᵀ_aux(
172181
(; conditions, backend, prep_Bᵀ) = implicit
173182
actual_backend = isnothing(backend) ? suggested_backend : backend
174183
contexts = (Constant(y), Constant(z), map(Constant, args)...)
184+
f_vec = VecToVec(conditions, x)
185+
x_vec = vec(x)
186+
dc_vec = vec(zero(y))
175187
prep_Bᵀ_same = prepare_pullback_same_point(
176-
conditions, prep_Bᵀ..., actual_backend, x, (zero(y),), contexts...
188+
f_vec, prep_Bᵀ..., actual_backend, x_vec, (dc_vec,), contexts...
177189
)
178-
prod! = VJP!(conditions, prep_Bᵀ_same, actual_backend, x, contexts)
190+
prod! = VJP!(f_vec, prep_Bᵀ_same, actual_backend, x_vec, contexts)
179191
return LinearOperator(
180-
eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x)
192+
eltype(y), length(x), length(y), SYMMETRIC, HERMITIAN, prod!, typeof(x_vec)
181193
)
182194
end

src/implicit_function.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂
2929
3030
## Positional arguments
3131
32-
- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractVector`, while `z` and `args` can be anything.
32+
- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractArray`, while `z` and `args` can be anything.
3333
- `conditions`: a callable returning an array of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation
3434
3535
## Keyword arguments
@@ -127,6 +127,6 @@ function Base.show(io::IO, implicit::ImplicitFunction)
127127
)
128128
end
129129

130-
function (implicit::ImplicitFunction)(x::AbstractVector, args::Vararg{Any,N}) where {N}
130+
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
131131
return implicit.solver(x, args...)
132132
end

0 commit comments

Comments
 (0)