Skip to content

Commit 8c30ba7

Browse files
authored
revamp!: use KrylovKit for type flexibility (beyond Vector), split out preparation (#180)
* revamp!: use KrylovKit for type flexibility (beyond Vector) * Fix * Add type stability test * Typo * Fix * Fixes * Fix * Strict * Fix coverage * Linear solver settings * More principled printin * Fix tests
1 parent bc7267a commit 8c30ba7

17 files changed

+212
-316
lines changed

Project.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ version = "0.9.0"
66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
9-
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
9+
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11-
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
12-
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
1311

1412
[weakdeps]
1513
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -35,9 +33,8 @@ ForwardDiff = "0.10.36, 1"
3533
JET = "0.9, 0.10"
3634
JuliaFormatter = "2.1.2"
3735
Krylov = "0.9.6, 0.10"
36+
KrylovKit = "0.9.5"
3837
LinearAlgebra = "1"
39-
LinearMaps = "3.11.4"
40-
LinearOperators = "2.8.0"
4138
NLsolve = "4.5.1"
4239
Optim = "1.12.0"
4340
Random = "1"

docs/Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
88
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
9-
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
1211
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
@@ -16,4 +15,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1615
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1716

1817
[compat]
19-
Documenter = "1.3"
18+
Documenter = "1.3"

docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ MatrixRepresentation
2323
OperatorRepresentation
2424
IterativeLinearSolver
2525
DirectLinearSolver
26-
prepare_implicit
2726
```
2827

2928
## Internals

examples/3_tricks.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ end;
4444

4545
# And build your implicit function like so:
4646

47-
implicit_components = ImplicitFunction(forward_components, conditions_components);
47+
implicit_components = ImplicitFunction(
48+
forward_components, conditions_components; strict=Val(false)
49+
);
4850

4951
# Now we're good to go.
5052

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,40 +14,43 @@ using ImplicitDifferentiation:
1414
# not covered by Codecov for now
1515
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1616

17+
struct ImplicitPullback{TA,TB,TL,TP,Nargs}
18+
Aᵀ::TA
19+
Bᵀ::TB
20+
linear_solver::TL
21+
project_x::TP
22+
_Nargs::Val{Nargs}
23+
end
24+
25+
function (pb::ImplicitPullback{TA,TB,TL,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,Nargs}
26+
(; Aᵀ, Bᵀ, linear_solver, project_x) = pb
27+
dc = linear_solver(Aᵀ, -unthunk(dy))
28+
dx = Bᵀ(dc)
29+
df = NoTangent()
30+
dargs = ntuple(unimplemented_tangent, Val(Nargs))
31+
return (df, project_x(dx), dargs...)
32+
end
33+
1734
function ChainRulesCore.rrule(
18-
rc::RuleConfig,
19-
implicit::ImplicitFunction,
20-
prep::ImplicitFunctionPreparation,
21-
x::AbstractArray,
22-
args::Vararg{Any,N};
35+
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
2336
) where {N}
37+
(; conditions, linear_solver) = implicit
2438
y, z = implicit(x, args...)
25-
c = implicit.conditions(x, y, z, args...)
39+
c = conditions(x, y, z, args...)
2640

2741
suggested_backend = chainrules_suggested_backend(rc)
42+
prep = ImplicitFunctionPreparation(eltype(x))
2843
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
2944
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
3045
project_x = ProjectTo(x)
3146

32-
function implicit_pullback_prepared((dy, dz))
33-
dy = unthunk(dy)
34-
dy_vec = vec(dy)
35-
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
36-
dx_vec = Bᵀ(dc_vec)
37-
dx = reshape(dx_vec, size(x))
38-
df = NoTangent()
39-
dprep = @not_implemented("Tangents for mutable arguments are not defined")
40-
dargs = ntuple(unimplemented_tangent, N)
41-
return (df, dprep, project_x(dx), dargs...)
42-
end
43-
44-
return (y, z), implicit_pullback_prepared
47+
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, project_x, Val(N))
48+
return (y, z), implicit_pullback
4549
end
4650

4751
function unimplemented_tangent(_)
4852
return @not_implemented(
4953
"Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented"
5054
)
5155
end
52-
5356
end

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using ImplicitDifferentiation:
66
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B
77

88
function (implicit::ImplicitFunction)(
9-
prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
9+
prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
1010
) where {T,R,N}
1111
x = value.(x_and_dx)
1212
y, z = implicit(x, args...)
@@ -19,14 +19,9 @@ function (implicit::ImplicitFunction)(
1919
dX = ntuple(Val(N)) do k
2020
partials.(x_and_dx, k)
2121
end
22-
dC_vec = map(dX) do dₖx
23-
dₖx_vec = vec(dₖx)
24-
dₖc_vec = B(dₖx_vec)
25-
return dₖc_vec
26-
end
27-
dY = map(dC_vec) do dₖc_vec
28-
dₖy_vec = implicit.linear_solver(A, -dₖc_vec)
29-
dₖy = reshape(dₖy_vec, size(y))
22+
dC = map(B, dX)
23+
dY = map(dC) do dₖc
24+
dₖy = implicit.linear_solver(A, -dₖc)
3025
return dₖy
3126
end
3227

@@ -37,4 +32,11 @@ function (implicit::ImplicitFunction)(
3732
return y_and_dy, z
3833
end
3934

35+
function (implicit::ImplicitFunction)(
36+
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
37+
) where {T,R,N}
38+
prep = ImplicitFunctionPreparation(R)
39+
return implicit(prep, x_and_dx, args...)
40+
end
41+
4042
end

src/ImplicitDifferentiation.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,9 @@ using DifferentiationInterface:
1717
prepare_pullback_same_point,
1818
prepare_pushforward,
1919
prepare_pushforward_same_point,
20-
pullback!,
2120
pullback,
22-
pushforward!,
2321
pushforward
24-
using Krylov: Krylov, krylov_workspace, krylov_solve!, solution
25-
using LinearOperators: LinearOperator
26-
using LinearMaps: FunctionMap
22+
using KrylovKit: linsolve
2723
using LinearAlgebra: factorize
2824

2925
include("utils.jl")
@@ -36,6 +32,5 @@ include("callable.jl")
3632
export MatrixRepresentation, OperatorRepresentation
3733
export IterativeLinearSolver, DirectLinearSolver
3834
export ImplicitFunction
39-
export prepare_implicit
4035

4136
end

src/callable.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
2-
return implicit(ImplicitFunctionPreparation(), x, args...)
2+
return implicit(ImplicitFunctionPreparation(eltype(x)), x, args...)
33
end
44

55
function (implicit::ImplicitFunction)(
6-
::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N}
7-
) where {N}
6+
::ImplicitFunctionPreparation{R}, x::AbstractArray{R}, args::Vararg{Any,N}
7+
) where {R<:Real,N}
88
return implicit.solver(x, args...)
99
end

0 commit comments

Comments
 (0)