Skip to content

Commit f95b799

Browse files
committed
format
1 parent f706c8f commit f95b799

File tree

2 files changed

+45
-63
lines changed

2 files changed

+45
-63
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module LinearSolveForwardDiffExt
1+
module LinearSolveForwardDiffExt
22

33
using LinearSolve
44
using ForwardDiff
@@ -7,50 +7,49 @@ using SciMLBase
77
using RecursiveArrayTools
88

99
const DualLinearProblem = LinearProblem{
10-
<:Union{Number,<:AbstractArray, Nothing},iip,
11-
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
12-
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
13-
<:Union{Number,<:AbstractArray, SciMLBase.NullParameters}
10+
<:Union{Number, <:AbstractArray, Nothing}, iip,
11+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
12+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
13+
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters}
1414
} where {iip, T, V, P}
1515

16-
1716
const DualALinearProblem = LinearProblem{
18-
<:Union{Number,<:AbstractArray, Nothing},
17+
<:Union{Number, <:AbstractArray, Nothing},
1918
iip,
20-
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
21-
<:Union{Number,<:AbstractArray},
22-
<:Union{Number,<:AbstractArray, SciMLBase.NullParameters}
19+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
20+
<:Union{Number, <:AbstractArray},
21+
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters}
2322
} where {iip, T, V, P}
2423

2524
const DualBLinearProblem = LinearProblem{
26-
<:Union{Number,<:AbstractArray, Nothing},
25+
<:Union{Number, <:AbstractArray, Nothing},
2726
iip,
28-
<:Union{Number,<:AbstractArray},
29-
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
30-
<:Union{Number,<:AbstractArray, SciMLBase.NullParameters}
27+
<:Union{Number, <:AbstractArray},
28+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
29+
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters}
3130
} where {iip, T, V, P}
3231

33-
const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
32+
const DualAbstractLinearProblem = Union{
33+
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3434

3535
function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...)
3636
@info "here!"
3737
new_A = nodual_value(prob.A)
3838
new_b = nodual_value(prob.b)
3939

40-
newprob = remake(prob; A=new_A, b=new_b)
40+
newprob = remake(prob; A = new_A, b = new_b)
4141

4242
sol = solve(newprob, alg, args...; kwargs...)
4343
uu = sol.u
4444

45-
4645
# Solves Dual partials separately
4746
∂_A = partial_vals(prob.A)
4847
∂_b = partial_vals(prob.b)
4948

5049
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5150

5251
partial_sols = map(rhs_list) do rhs
53-
partial_prob = remake(newprob, b=rhs)
52+
partial_prob = remake(newprob, b = rhs)
5453
solve(partial_prob, alg, args...; kwargs...).u
5554
end
5655

@@ -66,7 +65,8 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
6665
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
6766
end
6867

69-
function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
68+
function SciMLBase.solve(prob::DualAbstractLinearProblem,
69+
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
7070
sol, partials = linearsolve_forwarddiff_solve(
7171
prob, alg, args...; kwargs...
7272
)
@@ -82,28 +82,24 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciML
8282
return SciMLBase.build_linear_solution(
8383
alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
8484
)
85-
86-
8785
end
8886

89-
9087
function linearsolve_dual_solution(
91-
u::Number, partials, dual_type)
88+
u::Number, partials, dual_type)
9289
return dual_type(u, partials)
9390
end
9491

9592
function linearsolve_dual_solution(
96-
u::AbstractArray, partials, dual_type)
93+
u::AbstractArray, partials, dual_type)
9794
partials_list = RecursiveArrayTools.VectorOfArray(partials)
98-
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
95+
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
96+
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
9997
end
10098

101-
10299
get_dual_type(x::Dual) = typeof(x)
103100
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
104101
get_dual_type(x) = nothing
105102

106-
107103
partial_vals(x::Dual) = ForwardDiff.partials(x)
108104
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
109105
partial_vals(x) = nothing
@@ -112,59 +108,51 @@ nodual_value(x) = x
112108
nodual_value(x::Dual) = ForwardDiff.value(x)
113109
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
114110

115-
116-
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
111+
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
112+
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
117113
A_list = partials_to_list(∂_A)
118-
b_list = partials_to_list(∂_b)
114+
b_list = partials_to_list(∂_b)
119115

120-
Auu = [A*uu for A in A_list]
116+
Auu = [A * uu for A in A_list]
121117

122118
b_list .- Auu
123119
end
124120

125-
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
121+
function xp_linsolve_rhs(
122+
uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
126123
A_list = partials_to_list(∂_A)
127124

128-
Auu = [A*uu for A in A_list]
125+
Auu = [A * uu for A in A_list]
129126

130127
Auu
131128
end
132129

133-
function xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
130+
function xp_linsolve_rhs(
131+
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
134132
b_list = partials_to_list(∂_b)
135133

136134
b_list
137135
end
138136

139-
140-
141137
function partials_to_list(partial_matrix::Vector)
142138
p = eachindex(first(partial_matrix))
143-
[[partial[i] for partial in partial_matrix] for i in p]
139+
[[partial[i] for partial in partial_matrix] for i in p]
144140
end
145141

146142
function partials_to_list(partial_matrix)
147143
p = length(first(partial_matrix))
148-
m,n = size(partial_matrix)
149-
res_list = fill(zeros(m,n),p)
144+
m, n = size(partial_matrix)
145+
res_list = fill(zeros(m, n), p)
150146
for k in 1:p
151-
res = zeros(m,n)
147+
res = zeros(m, n)
152148
for i in 1:m
153149
for j in 1:n
154-
res[i,j] = partial_matrix[i,j][k]
150+
res[i, j] = partial_matrix[i, j][k]
155151
end
156152
end
157153
res_list[k] = res
158154
end
159155
return res_list
160156
end
161157

162-
end
163-
164-
165-
166-
167-
168-
169-
170-
158+
end

test/forwarddiff_overloads.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,26 @@ using LinearSolve
22
using ForwardDiff
33
using Test
44

5-
65
function h(p)
7-
(A=[p[1] p[2]+1 p[2]^3;
8-
3*p[1] p[1]+5 p[2]*p[1]-4;
9-
p[2]^2 9*p[1] p[2]],
10-
b=[p[1] + 1, p[2] * 2, p[1]^2])
6+
(A = [p[1] p[2]+1 p[2]^3;
7+
3*p[1] p[1]+5 p[2] * p[1]-4;
8+
p[2]^2 9*p[1] p[2]],
9+
b = [p[1] + 1, p[2] * 2, p[1]^2])
1110
end
1211

1312
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
1413

1514
prob = LinearProblem(A, b)
1615
overload_x_p = solve(prob)
17-
original_x_p = solve!(init(prob))
16+
original_x_p = solve!(init(prob))
1817

1918
@test overload_x_p original_x_p
2019

21-
2220
A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
2321
prob = LinearProblem(A, [6.0, 10.0, 25.0])
2422
@test solve(prob).retcode == ReturnCode.Default
2523

2624
_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
2725
A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0]
28-
prob = LinearProblem(A,b)
29-
@test solve(prob).retcode == ReturnCode.Default
30-
31-
32-
33-
26+
prob = LinearProblem(A, b)
27+
@test solve(prob).retcode == ReturnCode.Default

0 commit comments

Comments
 (0)