Skip to content

Commit 474654d

Browse files
authored
fix!: revamp linear operator handling (#172)
* fix!: make `B` a function and not a linear operator or a matrix * Fixes * Add support for LinearMaps * Allow IterativeSolvers, switch to testitems * Add buffer
1 parent 8069c1d commit 474654d

15 files changed

+454
-280
lines changed

Project.toml

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek"]
4-
version = "0.7.3"
4+
version = "0.8.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
9+
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
910
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
1113
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
1214

1315
[weakdeps]
@@ -22,12 +24,30 @@ ImplicitDifferentiationZygoteExt = "Zygote"
2224

2325
[compat]
2426
ADTypes = "1.9.0"
27+
Aqua = "0.8.13"
2528
ChainRulesCore = "1.25.0"
26-
DifferentiationInterface = "0.6.1"
29+
ChainRulesTestUtils = "1.13.0"
30+
ComponentArrays = "0.15.27"
31+
DifferentiationInterface = "0.6.1,0.7"
32+
Documenter = "1.12.0"
33+
ExplicitImports = "1"
34+
FiniteDiff = "2.27.0"
2735
ForwardDiff = "0.10.36, 1"
36+
IterativeSolvers = "0.9.4"
37+
JET = "0.9, 0.10"
38+
JuliaFormatter = "2.1.2"
2839
Krylov = "0.9.6, 0.10"
2940
LinearAlgebra = "1.10"
41+
LinearMaps = "3.11.4"
3042
LinearOperators = "2.8.0"
43+
NLsolve = "4.5.1"
44+
Optim = "1.12.0"
45+
Random = "1"
46+
SparseArrays = "1"
47+
StaticArrays = "1.9.13"
48+
Test = "1"
49+
TestItemRunner = "1.1.0"
50+
TestItems = "1.0.0"
3151
Zygote = "0.7.4"
3252
julia = "1.10"
3353

@@ -51,7 +71,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5171
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5272
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5373
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
74+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
75+
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
5476
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5577

5678
[targets]
57-
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]
79+
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"]

docs/src/api.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
```@meta
2-
CurrentModule = ImplicitDifferentiation
32
CollapsedDocStrings = true
43
```
54

@@ -20,7 +19,7 @@ ImplicitFunction
2019
### Settings
2120

2221
```@docs
23-
KrylovLinearSolver
22+
IterativeLinearSolver
2423
MatrixRepresentation
2524
OperatorRepresentation
2625
NoPreparation

docs/src/faq.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ Say your forward mapping takes multiple inputs and returns multiple outputs, suc
4646
The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single a `ComponentVector`, and do the same for all the outputs.
4747
See the examples for a demonstration.
4848

49-
!!! warning "Warning"
49+
!!! warning
50+
The default linear operator representation does not support ComponentArrays.jl: you need to select `representation=OperatorRepresentation{:LinearMaps}()` in the [`ImplicitFunction`](@ref) constructor for it to work.
51+
52+
!!! warning
5053
You may run into issues trying to differentiate through the `ComponentVector` constructor.
5154
For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`.
5255
Check out [this issue](https://github.com/gdalle/ImplicitDifferentiation.jl/issues/67) for a dirty workaround involving custom chain rules for the constructor.

examples/3_tricks.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ We demonstrate several features that may come in handy for some users.
77
using ComponentArrays
88
using ForwardDiff
99
using ImplicitDifferentiation
10-
using Krylov
1110
using LinearAlgebra
1211
using Test #src
1312
using Zygote
@@ -43,9 +42,13 @@ function conditions_components(x::ComponentVector, y::ComponentVector, _z)
4342
return c
4443
end;
4544

46-
# And build your implicit function like so.
45+
# And build your implicit function like so, switching the operator representation to avoid errors with ComponentArrays.
4746

48-
implicit_components = ImplicitFunction(forward_components, conditions_components);
47+
implicit_components = ImplicitFunction(
48+
forward_components,
49+
conditions_components;
50+
representation=OperatorRepresentation{:LinearMaps}(),
51+
);
4952

5053
# Now we're good to go.
5154

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,18 @@ function ChainRulesCore.rrule(
1717
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
1818
) where {N}
1919
y, z = implicit(x, args...)
20+
c = implicit.conditions(x, y, z, args...)
2021

2122
suggested_backend = chainrules_suggested_backend(rc)
22-
Aᵀ = build_Aᵀ(implicit, x, y, z, args...; suggested_backend)
23-
Bᵀ = build_Bᵀ(implicit, x, y, z, args...; suggested_backend)
23+
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
24+
Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend)
2425
project_x = ProjectTo(x)
2526

2627
function implicit_pullback((dy, dz))
2728
dy = unthunk(dy)
2829
dy_vec = vec(dy)
2930
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
30-
dx_vec = Bᵀ * dc_vec
31+
dx_vec = Bᵀ(dc_vec)
3132
dx = reshape(dx_vec, size(x))
3233
df = NoTangent()
3334
dargs = ntuple(unimplemented_tangent, N)

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,24 @@ function (implicit::ImplicitFunction)(
99
) where {T,R,N}
1010
x = value.(x_and_dx)
1111
y, z = implicit(x, args...)
12+
c = implicit.conditions(x, y, z, args...)
1213

1314
suggested_backend = AutoForwardDiff()
14-
A = build_A(implicit, x, y, z, args...; suggested_backend)
15-
B = build_B(implicit, x, y, z, args...; suggested_backend)
15+
A = build_A(implicit, x, y, z, c, args...; suggested_backend)
16+
B = build_B(implicit, x, y, z, c, args...; suggested_backend)
1617

1718
dX = ntuple(Val(N)) do k
1819
partials.(x_and_dx, k)
1920
end
2021
dC_mat = mapreduce(hcat, dX) do dₖx
2122
dₖx_vec = vec(dₖx)
22-
dₖc_vec = B * dₖx_vec
23+
dₖc_vec = B(dₖx_vec)
24+
return dₖc_vec
2325
end
2426
dY_mat = implicit.linear_solver(A, -dC_mat)
2527

26-
y_and_dy = map(LinearIndices(y)) do i
27-
Dual{T}(y[i], Partials(ntuple(k -> dY_mat[i, k], Val(N))))
28+
y_and_dy = map(y, LinearIndices(y)) do yi, i
29+
Dual{T}(yi, Partials(ntuple(k -> dY_mat[i, k], Val(N))))
2830
end
2931

3032
return y_and_dy, z

src/ImplicitDifferentiation.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@ using DifferentiationInterface:
1818
prepare_pushforward,
1919
prepare_pushforward_same_point,
2020
pullback!,
21-
pushforward!
22-
using Krylov: gmres
21+
pullback,
22+
pushforward!,
23+
pushforward
24+
using Krylov: Krylov
25+
using IterativeSolvers: IterativeSolvers
2326
using LinearOperators: LinearOperator
27+
using LinearMaps: FunctionMap
2428
using LinearAlgebra: factorize
2529

2630
include("utils.jl")
@@ -29,7 +33,7 @@ include("preparation.jl")
2933
include("implicit_function.jl")
3034
include("execution.jl")
3135

32-
export KrylovLinearSolver
36+
export IterativeLinearSolver
3337
export MatrixRepresentation, OperatorRepresentation
3438
export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation
3539
export ImplicitFunction

0 commit comments

Comments
 (0)