Skip to content

Commit b75add5

Browse files
authored
fix: ignore input type in LinearOperator construction by default (#174)
* fix: ignore input type in LinearOperator construction by default * Allow block_gmres
1 parent 74bb967 commit b75add5

File tree

5 files changed

+43
-26
lines changed

5 files changed

+43
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek"]
4-
version = "0.8.0"
4+
version = "0.8.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/faq.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ Say your forward mapping takes multiple inputs and returns multiple outputs, suc
4646
The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single a `ComponentVector`, and do the same for all the outputs.
4747
See the examples for a demonstration.
4848

49-
!!! warning
50-
The default linear operator representation does not support ComponentArrays.jl: you need to select `representation=OperatorRepresentation{:LinearMaps}()` in the [`ImplicitFunction`](@ref) constructor for it to work.
51-
5249
!!! warning
5350
You may run into issues trying to differentiate through the `ComponentVector` constructor.
5451
For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`.

examples/3_tricks.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,9 @@ function conditions_components(x::ComponentVector, y::ComponentVector, _z)
4242
return c
4343
end;
4444

45-
# And build your implicit function like so, switching the operator representation to avoid errors with ComponentArrays.
45+
# And build your implicit function like so:
4646

47-
implicit_components = ImplicitFunction(
48-
forward_components,
49-
conditions_components;
50-
representation=OperatorRepresentation{:LinearMaps}(),
51-
);
47+
implicit_components = ImplicitFunction(forward_components, conditions_components);
5248

5349
# Now we're good to go.
5450

src/execution.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ function build_A_aux(
6161
end
6262

6363
function build_A_aux(
64-
::OperatorRepresentation{package,symmetric,hermitian},
64+
::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type},
6565
implicit,
6666
x,
6767
y,
6868
z,
6969
c,
7070
args...;
7171
suggested_backend,
72-
) where {package,symmetric,hermitian}
72+
) where {package,symmetric,hermitian,posdef,keep_input_type}
7373
T = Base.promote_eltype(x, y, c)
7474
(; conditions, backends, prep_A) = implicit
7575
actual_backend = isnothing(backends) ? suggested_backend : backends.y
@@ -89,7 +89,13 @@ function build_A_aux(
8989
prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts)
9090
if package == :LinearOperators
9191
return LinearOperator(
92-
T, length(c), length(y), symmetric, hermitian, prod!; S=typeof(dy_vec)
92+
T,
93+
length(c),
94+
length(y),
95+
symmetric,
96+
hermitian,
97+
prod!;
98+
S=keep_input_type ? typeof(dy_vec) : Vector{T},
9399
)
94100
elseif package == :LinearMaps
95101
return FunctionMap{T}(
@@ -99,6 +105,7 @@ function build_A_aux(
99105
ismutating=true,
100106
issymmetric=symmetric,
101107
ishermitian=hermitian,
108+
isposdef=posdef,
102109
)
103110
end
104111
end
@@ -136,15 +143,15 @@ function build_Aᵀ_aux(
136143
end
137144

138145
function build_Aᵀ_aux(
139-
::OperatorRepresentation{package,symmetric,hermitian},
146+
::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type},
140147
implicit,
141148
x,
142149
y,
143150
z,
144151
c,
145152
args...;
146153
suggested_backend,
147-
) where {package,symmetric,hermitian}
154+
) where {package,symmetric,hermitian,posdef,keep_input_type}
148155
T = Base.promote_eltype(x, y, c)
149156
(; conditions, backends, prep_Aᵀ) = implicit
150157
actual_backend = isnothing(backends) ? suggested_backend : backends.y
@@ -164,7 +171,13 @@ function build_Aᵀ_aux(
164171
prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts)
165172
if package == :LinearOperators
166173
return LinearOperator(
167-
T, length(y), length(c), symmetric, hermitian, prod!; S=typeof(dc_vec)
174+
T,
175+
length(y),
176+
length(c),
177+
symmetric,
178+
hermitian,
179+
prod!;
180+
S=keep_input_type ? typeof(dc_vec) : Vector{T},
168181
)
169182
elseif package == :LinearMaps
170183
return FunctionMap{T}(
@@ -174,6 +187,7 @@ function build_Aᵀ_aux(
174187
ismutating=true,
175188
issymmetric=symmetric,
176189
ishermitian=hermitian,
190+
isposdef=posdef,
177191
)
178192
end
179193
end

src/settings.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Callable object that can solve linear systems `Ax = b` and `AX = B` in the same
1212
1313
The type parameter `package` can be either:
1414
15-
- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default)
15+
- `:Krylov` to use the solver `gmres` or `block_gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default)
1616
- `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl)
1717
1818
Keyword arguments are passed on to the respective solver.
@@ -94,36 +94,46 @@ Specify that the matrix `A` involved in the implicit function theorem should be
9494
9595
# Constructors
9696
97-
OperatorRepresentation(; symmetric=false, hermitian=false)
98-
OperatorRepresentation{package}(; symmetric=false, hermitian=false)
97+
OperatorRepresentation(;
98+
symmetric=false, hermitian=false, posdef=false, keep_input_type=false
99+
)
100+
OperatorRepresentation{package}(;
101+
symmetric=false, hermitian=false, posdef=false, keep_input_type=false
102+
)
99103
100104
The type parameter `package` can be either:
101105
102106
- `:LinearOperators` to use a wrapper from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (the default)
103107
- `:LinearMaps` to use a wrapper from [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl)
104108
105-
The keyword arguments `symmetric` and `hermitian` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, in case you can prove them.
109+
The keyword arguments `symmetric`, `hermitian` and `posdef` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, which are useful to the solver in case you can prove them.
110+
111+
The keyword argument `keep_input_type` dictates whether to force the linear operator to work with the provided input type, or fall back on a default.
106112
107113
# See also
108114
109115
- [`ImplicitFunction`](@ref)
110116
- [`MatrixRepresentation`](@ref)
111117
"""
112-
struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresentation
118+
struct OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type} <:
119+
AbstractRepresentation
113120
function OperatorRepresentation{package}(;
114-
symmetric::Bool=false, hermitian::Bool=false
121+
symmetric::Bool=false,
122+
hermitian::Bool=false,
123+
posdef::Bool=false,
124+
keep_input_type::Bool=false,
115125
) where {package}
116126
@assert package in [:LinearOperators, :LinearMaps]
117-
return new{package,symmetric,hermitian}()
127+
return new{package,symmetric,hermitian,posdef,keep_input_type}()
118128
end
119129
end
120130

121131
function Base.show(
122-
io::IO, ::OperatorRepresentation{package,symmetric,hermitian}
123-
) where {package,symmetric,hermitian}
132+
io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type}
133+
) where {package,symmetric,hermitian,posdef,keep_input_type}
124134
return print(
125135
io,
126-
"OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian)",
136+
"OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef, keep_input_type=$keep_input_type)",
127137
)
128138
end
129139

0 commit comments

Comments
 (0)