Skip to content

Commit 57e4778

Browse files
fix: call u0_constructor on resid_prototype
1 parent 52ce641 commit 57e4778

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -345,16 +345,6 @@ function hessian_sparsity(sys::NonlinearSystem)
345345
unknowns(sys)) for eq in equations(sys)]
346346
end
347347

348-
function calculate_resid_prototype(N, u0, p)
349-
u0ElType = u0 === nothing ? Float64 : eltype(u0)
350-
if SciMLStructures.isscimlstructure(p)
351-
u0ElType = promote_type(
352-
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
353-
u0ElType)
354-
end
355-
return zeros(u0ElType, N)
356-
end
357-
358348
"""
359349
```julia
360350
SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
@@ -381,6 +371,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
381371
eval_module = @__MODULE__,
382372
sparse = false, simplify = false,
383373
initialization_data = nothing, cse = true,
374+
resid_prototype = nothing,
384375
kwargs...) where {iip}
385376
if !iscomplete(sys)
386377
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`")
@@ -402,12 +393,6 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
402393
observedfun = ObservedFunctionCache(
403394
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)
404395

405-
if length(dvs) == length(equations(sys))
406-
resid_prototype = nothing
407-
else
408-
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
409-
end
410-
411396
NonlinearFunction{iip}(f;
412397
sys = sys,
413398
jac = _jac === nothing ? nothing : _jac,

src/systems/problem_utils.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,22 @@ function float_type_from_varmap(varmap, floatT = Bool)
10821082
return float(floatT)
10831083
end
10841084

1085+
"""
1086+
$(TYPEDSIGNATURES)
1087+
1088+
Calculate the `resid_prototype` for a `NonlinearFunction` with `N` equations and the
1089+
provided `u0` and `p`.
1090+
"""
1091+
function calculate_resid_prototype(N::Int, u0, p)
1092+
u0ElType = u0 === nothing ? Float64 : eltype(u0)
1093+
if SciMLStructures.isscimlstructure(p)
1094+
u0ElType = promote_type(
1095+
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
1096+
u0ElType)
1097+
end
1098+
return zeros(u0ElType, N)
1099+
end
1100+
10851101
"""
10861102
$(TYPEDSIGNATURES)
10871103
@@ -1293,7 +1309,14 @@ function process_SciMLProblem(
12931309
end
12941310
initialization_data = SciMLBase.remake_initialization_data(
12951311
kwargs.initialization_data, kwargs, u0, t0, p, u0, p)
1296-
kwargs = merge(kwargs,)
1312+
kwargs = merge(kwargs, (; initialization_data))
1313+
end
1314+
1315+
if constructor <: NonlinearFunction && length(dvs) != length(eqs)
1316+
kwargs = merge(kwargs,
1317+
(;
1318+
resid_prototype = u0_constructor(calculate_resid_prototype(
1319+
length(eqs), u0, p))))
12971320
end
12981321

12991322
f = constructor(sys, dvs, ps, u0; p = p,

test/nonlinearsystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,19 @@ end
442442
@test !in(D(y), vs)
443443
end
444444
end
445+
446+
@testset "oop `NonlinearLeastSquaresProblem` with `u0 === nothing`" begin
447+
@variables x y
448+
@named sys = NonlinearSystem([0 ~ x - y], [], []; observed = [x ~ 1.0, y ~ 1.0])
449+
prob = NonlinearLeastSquaresProblem{false}(complete(sys), nothing)
450+
sol = solve(prob)
451+
resid = sol.resid
452+
@test resid == [0.0]
453+
@test resid isa Vector
454+
prob = NonlinearLeastSquaresProblem{false}(
455+
complete(sys), nothing; u0_constructor = splat(SVector))
456+
sol = solve(prob)
457+
resid = sol.resid
458+
@test resid == [0.0]
459+
@test resid isa SVector
460+
end

0 commit comments

Comments
 (0)