1010
1111function  CacheWriter (sys:: AbstractSystem , buffer_types:: Vector{TypeT} ,
1212        exprs:: Dict{TypeT, Vector{Any}} , solsyms, obseqs:: Vector{Equation} ;
13-         eval_expression =  false , eval_module =  @__MODULE__ , cse =  true )
13+         eval_expression =  false , eval_module =  @__MODULE__ , cse =  true , sparse  =   false )
1414    ps =  parameters (sys; initial_parameters =  true )
1515    rps =  reorder_parameters (sys, ps)
1616    obs_assigns =  [eq. lhs ←  eq. rhs for  eq in  obseqs]
3939struct  SCCNonlinearFunction{iip} end 
4040
4141function  SCCNonlinearFunction {iip} (
42-         sys:: System , _eqs, _dvs, _obs, cachesyms; eval_expression =  false ,
42+         sys:: System , _eqs, _dvs, _obs, cachesyms, op ; eval_expression =  false ,
4343        eval_module =  @__MODULE__ , cse =  true , kwargs... ) where  {iip}
4444    ps =  parameters (sys; initial_parameters =  true )
45+     subsys =  System (
46+         _eqs, _dvs, ps; observed =  _obs, name =  nameof (sys), defaults =  defaults (sys))
47+     @set!  subsys. parameter_dependencies =  parameter_dependencies (sys)
48+     if  get_index_cache (sys) != =  nothing 
49+         @set!  subsys. index_cache =  subset_unknowns_observed (
50+             get_index_cache (sys), sys, _dvs, getproperty .(_obs, (:lhs ,)))
51+         @set!  subsys. complete =  true 
52+     end 
53+     #  generate linear problem instead
54+     if  isaffine (subsys)
55+         return  LinearFunction {iip} (
56+             subsys; eval_expression, eval_module, cse, cachesyms, kwargs... )
57+     end 
4558    rps =  reorder_parameters (sys, ps)
4659
4760    obs_assignments =  [eq. lhs ←  eq. rhs for  eq in  _obs]
@@ -54,14 +67,6 @@ function SCCNonlinearFunction{iip}(
5467    f_oop, f_iip =  eval_or_rgf .(f_gen; eval_expression, eval_module)
5568    f =  GeneratedFunctionWrapper {(2, 2, is_split(sys))} (f_oop, f_iip)
5669
57-     subsys =  System (_eqs, _dvs, ps; observed =  _obs,
58-         parameter_dependencies =  parameter_dependencies (sys), name =  nameof (sys))
59-     if  get_index_cache (sys) != =  nothing 
60-         @set!  subsys. index_cache =  subset_unknowns_observed (
61-             get_index_cache (sys), sys, _dvs, getproperty .(_obs, (:lhs ,)))
62-         @set!  subsys. complete =  true 
63-     end 
64- 
6570    return  NonlinearFunction {iip} (f; sys =  subsys)
6671end 
6772
@@ -70,7 +75,7 @@ function SciMLBase.SCCNonlinearProblem(sys::System, args...; kwargs...)
7075end 
7176
7277function  SciMLBase. SCCNonlinearProblem {iip} (sys:: System , op; eval_expression =  false ,
73-         eval_module =  @__MODULE__ , cse =  true , kwargs... ) where  {iip}
78+         eval_module =  @__MODULE__ , cse =  true , u0_constructor  =  identity,  kwargs... ) where  {iip}
7479    if  ! iscomplete (sys) ||  get_tearing_state (sys) ===  nothing 
7580        error (" A simplified `System` is required. Call `mtkcompile` on the system before creating an `SCCNonlinearProblem`." 
7681    end 
@@ -113,7 +118,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
113118
114119    _, u0,
115120    p =  process_SciMLProblem (
116-         EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, kwargs... )
121+         EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, symbolic_u0  =   true ,  kwargs... )
117122
118123    explicitfuns =  []
119124    nlfuns =  []
@@ -224,28 +229,57 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
224229            get (cachevars, T, [])
225230        end )
226231        f =  SCCNonlinearFunction {iip} (
227-             sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs... )
232+             sys, _eqs, _dvs, _obs, cachebufsyms, op;
233+             eval_expression, eval_module, cse, kwargs... )
228234        push! (nlfuns, f)
229235    end 
230236
237+     u0_eltype =  Union{}
238+     for  x in  u0
239+         symbolic_type (x) ==  NotSymbolic () ||  continue 
240+         u0_eltype =  typeof (x)
241+         break 
242+     end 
243+     if  u0_eltype ==  Union{}
244+         u0_eltype =  Float64
245+     end 
246+     u0_eltype =  float (u0_eltype)
247+ 
231248    if  ! isempty (cachetypes)
232249        templates =  map (cachetypes, cachesizes) do  T, n
233250            #  Real refers to `eltype(u0)`
234251            if  T ==  Real
235-                 T =  eltype (u0) 
252+                 T =  u0_eltype 
236253            elseif  T <:  Array  &&  eltype (T) ==  Real
237-                 T =  Array{eltype (u0) , ndims (T)}
254+                 T =  Array{u0_eltype , ndims (T)}
238255            end 
239256            BufferTemplate (T, n)
240257        end 
241258        p =  rebuild_with_caches (p, templates... )
242259    end 
243260
261+     #  yes, `get_p_constructor` since this is only used for `LinearProblem` and
262+     #  will retain the shape of `A`
263+     u0_constructor =  get_p_constructor (u0_constructor, typeof (u0), u0_eltype)
244264    subprobs =  []
245-     for  (f, vscc) in  zip (nlfuns, var_sccs)
265+     for  (i, ( f, vscc))  in  enumerate ( zip (nlfuns, var_sccs) )
246266        _u0 =  SymbolicUtils. Code. create_array (
247267            typeof (u0), eltype (u0), Val (1 ), Val (length (vscc)), u0[vscc]. .. )
248-         prob =  NonlinearProblem (f, _u0, p)
268+         symbolic_idxs =  findall (x ->  symbolic_type (x) !=  NotSymbolic (), _u0)
269+         explicitfuns[i](p, subprobs)
270+         if  f isa  LinearFunction
271+             _u0 =  isempty (symbolic_idxs) ?  _u0 :  zeros (u0_eltype, length (_u0))
272+             _u0 =  u0_eltype .(_u0)
273+             symbolic_interface =  f. interface
274+             A,
275+             b =  get_A_b_from_LinearFunction (
276+                 sys, f, p; eval_expression, eval_module, u0_constructor, u0_eltype)
277+             prob =  LinearProblem {iip} (A, b, p; f =  symbolic_interface, u0 =  _u0)
278+         else 
279+             isempty (symbolic_idxs) ||  throw (MissingGuessError (dvs[vscc], _u0))
280+             _u0 =  u0_eltype .(_u0)
281+             prob =  NonlinearProblem (f, _u0, p)
282+         end 
249283        push! (subprobs, prob)
250284    end 
251285
@@ -255,5 +289,5 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
255289    @set!  sys. eqs =  new_eqs
256290    @set!  sys. index_cache =  subset_unknowns_observed (
257291        get_index_cache (sys), sys, new_dvs, getproperty .(obs, (:lhs ,)))
258-     return  SCCNonlinearProblem (subprobs,  explicitfuns, p, true ; sys)
292+     return  SCCNonlinearProblem (Tuple ( subprobs),  Tuple ( explicitfuns) , p, true ; sys)
259293end 
0 commit comments