Skip to content

Commit a1d14d3

Browse files
Merge pull request #1053 from AayushSabharwal/as/fix-linprob-remake
feat: implement `remake(::LinearProblem)`
2 parents 2f86159 + 9b4d3ec commit a1d14d3

File tree

3 files changed

+77
-2
lines changed

3 files changed

+77
-2
lines changed

src/problems/linear_problems.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""
22
$(TYPEDEF)
33
4-
A utility struct stored inside `LinearProblem` to enable a symbolic interface.
4+
A utility struct stored inside `LinearProblem` to enable a symbolic interface. Intended for
5+
use by ModelingToolkit.jl.
56
67
# Fields
78

src/remake.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,49 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
901901
probs, explicitfuns!, f, newp, parameters_alias)
902902
end
903903

904+
function remake(prob::LinearProblem; u0 = missing, p = missing, A = missing, b = missing,
905+
f = missing, interpret_symbolicmap = true, use_defaults = false, kwargs = missing,
906+
_kwargs...)
907+
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
908+
f = coalesce(f, prob.f)
909+
# We want to copy to avoid aliasing, but don't want to unnecessarily copy
910+
A = @coalesce(A, copy(prob.A))
911+
b = @coalesce(b, copy(prob.b))
912+
913+
A, b = _get_new_A_b(f, p, A, b)
914+
915+
if kwargs === missing
916+
return LinearProblem{isinplace(prob)}(A, b, p; u0, f, prob.kwargs..., _kwargs...)
917+
else
918+
return LinearProblem{isinplace(prob)}(A, b, p; u0, f, kwargs...)
919+
end
920+
end
921+
922+
"""
923+
$(TYPEDSIGNATURES)
924+
925+
A helper function to call `get_new_A_b` if `f isa SymbolicLinearInterface`.
926+
"""
927+
_get_new_A_b(f, p, A, b; kw...) = A, b
928+
929+
function _get_new_A_b(f::SymbolicLinearInterface, p, A, b; kw...)
930+
get_new_A_b(f.sys, f, p, A, b; kw...)
931+
end
932+
933+
# public API
934+
"""
935+
$(TYPEDSIGNATURES)
936+
937+
A function to return the updated `A` and `b` matrices for a `LinearProblem` after `remake`.
938+
`root_indp` is the innermost index provider found by recursively, calling
939+
`SymbolicIndexingInterface.symbolic_container`, provided for dispatch. Returns the new `A`
940+
`b` matrices. Mutation of `A` and `b` is permitted.
941+
942+
All implementations must accept arbitrary keyword arguments in case they are added in the
943+
future.
944+
"""
945+
get_new_A_b(root_indp, f, p, A, b; kw...) = A, b
946+
904947
function varmap_has_var(varmap, var)
905948
haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var))
906949
end
@@ -1151,7 +1194,7 @@ function updated_u0_p(
11511194
if u0 === missing && p === missing
11521195
return state_values(prob), parameter_values(prob)
11531196
end
1154-
if has_sys(prob.f) && prob.f.sys === nothing
1197+
if prob.f !== nothing && has_sys(prob.f) && prob.f.sys === nothing
11551198
if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
11561199
throw(ArgumentError("This problem does not support symbolic maps with " *
11571200
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *

test/remake_tests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,20 @@ for T in containerTypes
8686
push!(probs, NonlinearLeastSquaresProblem(fn, u0, T(p)))
8787
end
8888

89+
update_A! = function (A, p)
90+
A[1, 1] = p[1]
91+
A[2, 2] = p[2]
92+
A[3, 3] = p[3]
93+
end
94+
update_b! = function (b, p)
95+
b[1] = p[3]
96+
b[2] = -8p[2] - p[1]
97+
end
98+
f = SciMLBase.SymbolicLinearInterface(update_A!, update_b!, indep_sys, nothing, nothing)
99+
for T in containerTypes
100+
push!(probs, LinearProblem(rand(3, 3), rand(3), T(p); u0, f))
101+
end
102+
89103
# temporary definition to test this functionality
90104
function SciMLBase.late_binding_update_u0_p(
91105
prob, u0, p::SciMLBase.NullParameters, t0, newu0, newp)
@@ -429,3 +443,20 @@ end
429443
prob2 = remake(ODEProblem((u, p, t) -> 2 .* u, nothing, nothing); f = f)
430444
@test SciMLBase.specialization(prob2.f) == SciMLBase.FullSpecialize
431445
end
446+
447+
@testset "`remake(::LinearProblem)` without a system" begin
448+
prob = LinearProblem{true}(rand(3, 3), rand(3))
449+
@inferred remake(prob)
450+
base_allocs = @allocations remake(prob)
451+
A = ones(3, 3)
452+
b = ones(3)
453+
u0 = ones(3)
454+
p = "P"
455+
@inferred remake(prob; A, b, u0, p)
456+
@test (@allocations remake(prob; A, b, u0, p)) <= base_allocs
457+
458+
prob2 = remake(prob; u0)
459+
@test prob2.u0 === u0
460+
prob2 = remake(prob; A = SMatrix{3, 3}(A))
461+
@test prob2.A isa SMatrix{3, 3}
462+
end

0 commit comments

Comments
 (0)