@@ -39,14 +39,12 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs
39
39
∂_A = partial_vals (A)
40
40
∂_b = partial_vals (b)
41
41
42
-
42
+ rhs = xp_linsolve_rhs (uu, ∂_A, ∂_b)
43
43
44
- if uu isa Number
45
-
46
- else
47
-
48
- end
44
+ partial_prob = remake (newprob, b = rhs)
45
+ partial_sol = solve (partial_prob, alg, args... ; kwargs... )
49
46
47
+ sol, partial_sol
50
48
end
51
49
52
50
@@ -60,33 +58,27 @@ nodual_value(x::Dual) = ForwardDiff.value(x)
60
58
nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
61
59
62
60
63
- function x_p_linsolve (new_A, uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
61
+ function xp_linsolve_rhs ( uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
64
62
A_list = partials_to_list (∂_A)
65
63
b_list = partials_to_list (∂_b)
66
64
67
65
Auu = [A* uu for A in A_list]
68
66
69
- linsol_rhs = reduce (hcat, b_list .- Auu)
70
-
71
- new_A \ linsol_rhs
67
+ reduce (hcat, b_list .- Auu)
72
68
end
73
69
74
- function x_p_linsolve (new_A, uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Nothing )
70
+ function xp_linsolve_rhs ( uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Nothing )
75
71
A_list = partials_to_list (∂_A)
76
72
77
73
Auu = [A* uu for A in A_list]
78
74
79
- linsol_rhs = reduce (hcat, Auu)
80
-
81
- new_A \ linsol_rhs
75
+ reduce (hcat, Auu)
82
76
end
83
77
84
- function x_p_linsolve (new_A, uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
78
+ function xp_linsolve_rhs ( uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
85
79
b_list = partials_to_list (∂_b)
86
80
87
- linsol_rhs = reduce (hcat, b_list)
88
-
89
- new_A \ linsol_rhs
81
+ reduce (hcat, b_list)
90
82
end
91
83
92
84
0 commit comments