@@ -2,7 +2,9 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
22    nlsys, outer_tmp, inner_tmp =  inner_nlsystem (sys, mm)
33    state =  ProblemState (; u =  u0, p)
44    op =  Dict ()
5-     op[ODE_GAMMA] =  one (eltype (u0))
5+     op[ODE_GAMMA[1 ]] =  one (eltype (u0))
6+     op[ODE_GAMMA[2 ]] =  one (eltype (u0))
7+     op[ODE_GAMMA[3 ]] =  one (eltype (u0))
68    op[ODE_C] =  zero (eltype (u0))
79    op[outer_tmp] =  zeros (eltype (u0), size (outer_tmp))
810    op[inner_tmp] =  zeros (eltype (u0), size (inner_tmp))
@@ -11,15 +13,17 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
1113        op[v] =  getsym (sys, v)(state)
1214    end 
1315    nlprob =  NonlinearProblem (nlsys, op; build_initializeprob =  false )
16+ 
17+     subsetidxs =  [findfirst (isequal (y),unknowns (sys)) for  y in  unknowns (nlsys)]
1418    set_gamma_c =  setsym (nlsys, (ODE_GAMMA... , ODE_C))
1519    set_outer_tmp =  setsym (nlsys, outer_tmp)
1620    set_inner_tmp =  setsym (nlsys, inner_tmp)
1721    nlprobmap =  getsym (nlsys, unknowns (sys))
1822
19-     return  SciMLBase. ODENLStepData (nlprob, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
23+     return  SciMLBase. ODENLStepData (nlprob, subsetidxs,  set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
2024end 
2125
22- const  ODE_GAMMA =  @parameters  γ₁ₘₜₖ, γ₂ₘₜₖ
26+ const  ODE_GAMMA =  @parameters  γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ  
2327const  ODE_C =  only (@parameters  cₘₜₖ)
2428
2529function  get_outer_tmp (n:: Int )
@@ -38,19 +42,19 @@ function inner_nlsystem(sys::System, mm)
3842    @assert  length (eqs) ==  N
3943    @assert  mm ==  I ||  size (mm) ==  (N, N)
4044    rhss =  [eq. rhs for  eq in  eqs]
41-     gamma1, gamma2 =  ODE_GAMMA
45+     gamma1, gamma2, gamma3  =  ODE_GAMMA
4246    c =  ODE_C
4347    outer_tmp =  get_outer_tmp (N)
4448    inner_tmp =  get_inner_tmp (N)
4549
4650    subrules =  Dict ([v =>  gamma2* v +  inner_tmp[i] for  (i, v) in  enumerate (dvs)])
4751    subrules[t] =  t +  c
4852    new_rhss =  map (Base. Fix2 (fast_substitute, subrules), rhss)
49-     new_rhss =  mm  *  dvs  -   gamma1 .*  new_rhss .+    collect (outer_tmp) 
53+     new_rhss =  collect (outer_tmp)  .+   gamma1 .*  new_rhss .-   gamma3  *  mm  *  dvs 
5054    new_eqs =  [0  ~  rhs for  rhs in  new_rhss]
5155
5256    new_dvs =  unknowns (sys)
53-     new_ps =  [parameters (sys); [gamma1, gamma2, c, inner_tmp, outer_tmp]]
57+     new_ps =  [parameters (sys); [gamma1, gamma2, gamma3,  c, inner_tmp, outer_tmp]]
5458    nlsys =  mtkcompile (System (new_eqs, new_dvs, new_ps; name =  :nlsys ); split =  is_split (sys))
5559    return  nlsys, outer_tmp, inner_tmp
5660end 
0 commit comments