@@ -107,53 +107,163 @@ function SciMLBase.solve!(cache::MIRKNCache{iip, T}) where {iip, T}
107107end 
108108
109109function  __perform_mirkn_iteration (cache:: MIRKNCache ; nlsolve_kwargs =  (;), kwargs... )
110-     nlprob:: NonlinearProblem  =  __construct_nlproblem (cache, vec (cache. y₀))
110+     nlprob:: NonlinearProblem  =  __construct_nlproblem (cache, vec (cache. y₀),  copy (cache . y₀) )
111111    nlsolve_alg =  __concrete_nonlinearsolve_algorithm (nlprob, cache. alg. nlsolve)
112112    sol_nlprob =  __solve (nlprob, nlsolve_alg; kwargs... , nlsolve_kwargs... , alias_u0 =  true )
113113    recursive_unflatten! (cache. y₀, sol_nlprob. u)
114114
115115    return  sol_nlprob, sol_nlprob. retcode
116116end 
117117
118- function   __construct_nlproblem (cache :: MIRKNCache{iip} , y)  where  {iip} 
119-     (; jac_alg)  =  cache . alg 
120-     (; diffmode)  =  jac_alg 
118+ #  Constructing the Nonlinear Problem 
119+ function   __construct_nlproblem ( 
120+         cache :: MIRKNCache{iip} , y :: AbstractVector , y₀ :: AbstractVectorOfArray )  where  {iip} 
121121    pt =  cache. problem_type
122122
123+     eval_sol =  EvalSol (
124+         __restructure_sol (y₀. u[1 : length (cache. mesh)], cache. in_size), cache. mesh, cache)
125+     eval_dsol =  EvalSol (
126+         __restructure_sol (y₀. u[(length (cache. mesh) +  1 ): end ], cache. in_size),
127+         cache. mesh, cache)
128+ 
129+     loss_bc =  if  iip
130+         @closure  (du, u, p) ->  __mirkn_loss_bc! (
131+             du, u, p, pt, cache. bc, cache. y, cache. mesh, cache)
132+     else 
133+         @closure  (u, p) ->  __mirkn_loss_bc (u, p, pt, cache. bc, cache. y, cache. mesh, cache)
134+     end 
135+ 
136+     loss_collocation =  if  iip
137+         @closure  (du, u, p) ->  __mirkn_loss_collocation! (
138+             du, u, p, cache. y, cache. mesh, cache. residual, cache)
139+     else 
140+         @closure  (u, p) ->  __mirkn_loss_collocation (
141+             u, p, cache. y, cache. mesh, cache. residual, cache)
142+     end 
143+ 
123144    loss =  if  iip
124145        @closure  (du, u, p) ->  __mirkn_loss! (
125-             du, u, p, cache. y, pt, cache. bc, cache. residual, cache. mesh, cache)
146+             du, u, p, cache. y, pt, cache. bc, cache. residual,
147+             cache. mesh, cache, eval_sol, eval_dsol)
126148    else 
127-         @closure  (u, p) ->  __mirkn_loss (u, p, cache. y, pt, cache. bc, cache. mesh, cache)
149+         @closure  (u, p) ->  __mirkn_loss (
150+             u, p, cache. y, pt, cache. bc, cache. mesh, cache, eval_sol, eval_dsol)
128151    end 
129152
130-     resid_prototype =  __similar (y)
153+     return  __construct_nlproblem (cache, y, loss_bc, loss_collocation, loss, pt)
154+ end 
155+ 
156+ function  __construct_nlproblem (cache:: MIRKNCache{iip} , y, loss_bc:: BC , loss_collocation:: C ,
157+         loss:: LF , :: StandardSecondOrderBVProblem ) where  {iip, BC, C, LF}
158+     (; jac_alg) =  cache. alg
159+     N =  length (cache. mesh)
131160
132-     jac_cache =  if  iip
133-         DI. prepare_jacobian (loss, resid_prototype, diffmode, y, Constant (cache. p))
161+     resid_bc =  cache. bcresid_prototype
162+     L =  length (resid_bc)
163+     resid_collocation =  __similar (y, cache. M *  (2  *  N -  2 ))
164+ 
165+     bc_diffmode =  if  jac_alg. bc_diffmode isa  AutoSparse
166+         AutoSparse (get_dense_ad (jac_alg. bc_diffmode);
167+             sparsity_detector =  __default_sparsity_detector (jac_alg. bc_diffmode),
168+             coloring_algorithm =  __default_coloring_algorithm (jac_alg. bc_diffmode))
134169    else 
135-         DI . prepare_jacobian (loss, diffmode, y,  Constant (cache . p)) 
170+         jac_alg . bc_diffmode 
136171    end 
137172
138-     jac_prototype =  if  iip
139-         DI. jacobian (loss, resid_prototype, jac_cache, diffmode, y, Constant (cache. p))
173+     cache_bc =  if  iip
174+         DI. prepare_jacobian (loss_bc, resid_bc, bc_diffmode, y, Constant (cache. p))
175+     else 
176+         DI. prepare_jacobian (loss_bc, bc_diffmode, y, Constant (cache. p))
177+     end 
178+ 
179+     nonbc_diffmode =  if  jac_alg. nonbc_diffmode isa  AutoSparse
180+         AutoSparse (get_dense_ad (jac_alg. nonbc_diffmode);
181+             sparsity_detector =  __default_sparsity_detector (jac_alg. nonbc_diffmode),
182+             coloring_algorithm =  __default_coloring_algorithm (jac_alg. nonbc_diffmode))
140183    else 
141-         DI . jacobian (loss, jac_cache, diffmode, y,  Constant (cache . p)) 
184+         jac_alg . nonbc_diffmode 
142185    end 
143186
187+     cache_collocation =  if  iip
188+         DI. prepare_jacobian (
189+             loss_collocation, resid_collocation, nonbc_diffmode, y, Constant (cache. p))
190+     else 
191+         DI. prepare_jacobian (loss_collocation, nonbc_diffmode, y, Constant (cache. p))
192+     end 
193+ 
194+     J_bc =  if  iip
195+         DI. jacobian (loss_bc, resid_bc, cache_bc, bc_diffmode, y, Constant (cache. p))
196+     else 
197+         DI. jacobian (loss_bc, cache_bc, bc_diffmode, y, Constant (cache. p))
198+     end 
199+     J_c =  if  iip
200+         DI. jacobian (loss_collocation, resid_collocation, cache_collocation,
201+             nonbc_diffmode, y, Constant (cache. p))
202+     else 
203+         DI. jacobian (
204+             loss_collocation, cache_collocation, nonbc_diffmode, y, Constant (cache. p))
205+     end 
206+ 
207+     jac_prototype =  vcat (J_bc, J_c)
208+ 
144209    jac =  if  iip
145210        @closure  (J, u, p) ->  __mirkn_mpoint_jacobian! (
146-             J, u, diffmode, jac_cache, loss, resid_prototype, cache. p)
211+             J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation,
212+             loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache. p)
147213    else 
148214        @closure  (u, p) ->  __mirkn_mpoint_jacobian (
149-             jac_prototype, u, diffmode, jac_cache, loss, cache. p)
215+             jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc,
216+             cache_collocation, loss_bc, loss_collocation, L, cache. p)
150217    end 
151- 
218+     resid_prototype  =   vcat (resid_bc, resid_collocation) 
152219    nlf =  NonlinearFunction {iip} (
153220        loss; jac =  jac, resid_prototype =  resid_prototype, jac_prototype =  jac_prototype)
154221    __internal_nlsolve_problem (cache. prob, resid_prototype, y, nlf, y, cache. p)
155222end 
156223
224+ function  __construct_nlproblem (cache:: MIRKNCache{iip} , y, loss_bc:: BC , loss_collocation:: C ,
225+         loss:: LF , :: TwoPointSecondOrderBVProblem ) where  {iip, BC, C, LF}
226+     (; nlsolve, jac_alg) =  cache. alg
227+     N =  length (cache. mesh)
228+ 
229+     resid =  vcat (@view (cache. bcresid_prototype[1 : prod (cache. resid_size[1 ])]),
230+         __similar (y, cache. M *  2  *  (N -  1 )),
231+         @view (cache. bcresid_prototype[(prod (cache. resid_size[1 ]) +  1 ): end ]))
232+ 
233+     diffmode =  if  jac_alg. diffmode isa  AutoSparse
234+         AutoSparse (get_dense_ad (jac_alg. diffmode);
235+             sparsity_detector =  __default_sparsity_detector (jac_alg. diffmode),
236+             coloring_algorithm =  __default_coloring_algorithm (jac_alg. diffmode))
237+     else 
238+         jac_alg. diffmode
239+     end 
240+ 
241+     diffcache =  if  iip
242+         DI. prepare_jacobian (loss, resid, diffmode, y, Constant (cache. p))
243+     else 
244+         DI. prepare_jacobian (loss, diffmode, y, Constant (cache. p))
245+     end 
246+ 
247+     jac_prototype =  if  iip
248+         DI. jacobian (loss, resid, diffcache, diffmode, y, Constant (cache. p))
249+     else 
250+         DI. jacobian (loss, diffcache, diffmode, y, Constant (cache. p))
251+     end 
252+ 
253+     jac =  if  iip
254+         @closure  (J, u, p) ->  __mirkn_2point_jacobian! (
255+             J, u, jac_alg. diffmode, diffcache, loss, resid, p)
256+     else 
257+         @closure  (u, p) ->  __mirkn_2point_jacobian (
258+             u, jac_prototype, jac_alg. diffmode, diffcache, loss, p)
259+     end 
260+ 
261+     resid_prototype =  copy (resid)
262+     nlf =  NonlinearFunction {iip} (
263+         loss; jac =  jac, resid_prototype =  resid_prototype, jac_prototype =  jac_prototype)
264+     return  __internal_nlsolve_problem (cache. prob, resid_prototype, y, nlf, y, cache. p)
265+ end 
266+ 
157267function  __mirkn_2point_jacobian! (J, x, diffmode, diffcache, loss_fn:: L , resid, p) where  {L}
158268    DI. jacobian! (loss_fn, resid, J, diffcache, diffmode, x, Constant (p))
159269    return  J
@@ -164,44 +274,85 @@ function __mirkn_2point_jacobian(x, J, diffmode, diffcache, loss_fn::L, p) where
164274    return  J
165275end 
166276
167- function  __mirkn_mpoint_jacobian! (J, x, diffmode, diffcache, loss, resid, p)
168-     DI. jacobian! (loss, resid, J, diffcache, diffmode, x, Constant (p))
277+ function  __mirkn_mpoint_jacobian! (
278+         J, _, x, bc_diffmode, nonbc_diffmode, bc_diffcache, nonbc_diffcache, loss_bc:: BC ,
279+         loss_collocation:: C , resid_bc, resid_collocation, L:: Int , p) where  {BC, C}
280+     DI. jacobian! (
281+         loss_bc, resid_bc, @view (J[1 : L, :]), bc_diffcache, bc_diffmode, x, Constant (p))
282+     DI. jacobian! (loss_collocation, resid_collocation, @view (J[(L +  1 ): end , :]),
283+         nonbc_diffcache, nonbc_diffmode, x, Constant (p))
169284    return  nothing 
170285end 
171286
172- function  __mirkn_mpoint_jacobian (J, x, diffmode, diffcache, loss, p)
173-     DI. jacobian! (loss, J, diffcache, diffmode, x, Constant (p))
287+ function  __mirkn_mpoint_jacobian (
288+         J, _, x, bc_diffmode, nonbc_diffmode, bc_diffcache, nonbc_diffcache,
289+         loss_bc:: BC , loss_collocation:: C , L:: Int , p) where  {BC, C}
290+     DI. jacobian! (loss_bc, @view (J[1 : L, :]), bc_diffcache, bc_diffmode, x, Constant (p))
291+     DI. jacobian! (loss_collocation, @view (J[(L +  1 ): end , :]),
292+         nonbc_diffcache, nonbc_diffmode, x, Constant (p))
174293    return  J
175294end 
176295
177- @views  function  __mirkn_loss! (resid, u, p, y, pt:: StandardSecondOrderBVProblem ,
178-         bc :: BC ,  residual, mesh, cache:: MIRKNCache ) where  {BC}
296+ @views  function  __mirkn_loss! (resid, u, p, y, pt:: StandardSecondOrderBVProblem , bc :: BC , 
297+         residual, mesh, cache:: MIRKNCache , EvalSol, EvalDSol ) where  {BC}
179298    y_ =  recursive_unflatten! (y, u)
180299    resids =  [get_tmp (r, u) for  r in  residual]
181300    Φ! (resids[3 : end ], cache, y_, u, p)
182-     soly_  =   EvalSol ( 
183-          __restructure_sol (y_ [1 : length (cache . mesh)], cache . in_size),  cache. mesh, cache) 
184-     dsoly_  =   EvalSol ( __restructure_sol (y_[(length (cache. mesh) +  1 ): end ], cache. in_size), 
185-          cache. mesh,  cache) 
186-     eval_bc_residual! (resids[1 : 2 ], pt, bc, soly_, dsoly_ , p, mesh)
301+     EvalSol . u[ 1 : end ] . =   __restructure_sol (y_[ 1 : length (cache . mesh)], cache . in_size) 
302+     EvalSol . cache . k_discrete [1 : end ] . =  cache. k_discrete 
303+     EvalDSol . u[ 1 : end ] . =   __restructure_sol (y_[(length (cache. mesh) +  1 ): end ], cache. in_size)
304+     EvalDSol . cache. k_discrete[ 1 : end ] . =   cache. k_discrete 
305+     eval_bc_residual! (resids[1 : 2 ], pt, bc, EvalSol, EvalDSol , p, mesh)
187306    recursive_flatten! (resid, resids)
188307    return  nothing 
189308end 
190309
191- @views  function  __mirkn_loss (u, p, y, pt:: StandardSecondOrderBVProblem ,
192-         bc :: BC ,  mesh, cache:: MIRKNCache ) where  {BC}
310+ @views  function  __mirkn_loss (u, p, y, pt:: StandardSecondOrderBVProblem , bc :: BC , 
311+         mesh, cache:: MIRKNCache , EvalSol, EvalDSol ) where  {BC}
193312    y_ =  recursive_unflatten! (y, u)
194313    resid_co =  Φ (cache, y_, u, p)
195-     soly_ =  EvalSol (
196-         __restructure_sol (y_[1 : length (cache. mesh)], cache. in_size), cache. mesh, cache)
314+     EvalSol. u[1 : end ] .=  __restructure_sol (y_[1 : length (cache. mesh)], cache. in_size)
315+     EvalSol. cache. k_discrete[1 : end ] .=  cache. k_discrete
316+     EvalDSol. u[1 : end ] .=  __restructure_sol (y_[(length (cache. mesh) +  1 ): end ], cache. in_size)
317+     EvalDSol. cache. k_discrete[1 : end ] .=  cache. k_discrete
318+     resid_bc =  eval_bc_residual (pt, bc, EvalSol, EvalDSol, p, mesh)
319+     return  vcat (resid_bc, mapreduce (vec, vcat, resid_co))
320+ end 
321+ 
322+ @views  function  __mirkn_loss_bc! (
323+         resid, u, p, pt, bc!:: BC , y, mesh, cache:: MIRKNCache ) where  {BC}
324+     y_ =  recursive_unflatten! (y, u)
325+     soly_ =  EvalSol (__restructure_sol (y_[1 : length (cache. mesh)], cache. in_size), mesh, cache)
197326    dsoly_ =  EvalSol (__restructure_sol (y_[(length (cache. mesh) +  1 ): end ], cache. in_size),
198327        cache. mesh, cache)
199-     resid_bc =  eval_bc_residual (pt, bc, soly_, dsoly_, p, mesh)
200-     return  vcat (resid_bc, mapreduce (vec, vcat, resid_co))
328+     eval_bc_residual! (resid, pt, bc!, soly_, dsoly_, p, mesh)
329+     return  nothing 
330+ end 
331+ 
332+ @views  function  __mirkn_loss_bc (u, p, pt, bc!:: BC , y, mesh, cache:: MIRKNCache ) where  {BC}
333+     y_ =  recursive_unflatten! (y, u)
334+     soly_ =  EvalSol (__restructure_sol (y_[1 : length (cache. mesh)], cache. in_size), mesh, cache)
335+     dsoly_ =  EvalSol (__restructure_sol (y_[(length (cache. mesh) +  1 ): end ], cache. in_size),
336+         cache. mesh, cache)
337+     return  eval_bc_residual (pt, bc!, soly_, dsoly_, p, mesh)
338+ end 
339+ 
340+ @views  function  __mirkn_loss_collocation! (resid, u, p, y, mesh, residual, cache)
341+     y_ =  recursive_unflatten! (y, u)
342+     resids =  [get_tmp (r, u) for  r in  residual[3 : end ]]
343+     Φ! (resids, cache, y_, u, p)
344+     recursive_flatten! (resid, resids)
345+     return  nothing 
346+ end 
347+ 
348+ @views  function  __mirkn_loss_collocation (u, p, y, mesh, residual, cache)
349+     y_ =  recursive_unflatten! (y, u)
350+     resids =  Φ (cache, y_, u, p)
351+     return  mapreduce (vec, vcat, resids)
201352end 
202353
203- @views  function  __mirkn_loss! (resid, u, p, y, pt:: TwoPointSecondOrderBVProblem ,
204-         bc! :: BC ,  residual, mesh, cache:: MIRKNCache ) where  {BC}
354+ @views  function  __mirkn_loss! (resid, u, p, y, pt:: TwoPointSecondOrderBVProblem , bc! :: BC , 
355+         residual, mesh, cache:: MIRKNCache , _, _ ) where  {BC}
205356    y_ =  recursive_unflatten! (y, u)
206357    resids =  [get_tmp (r, u) for  r in  residual]
207358    soly_ =  VectorOfArray (y_)
212363end 
213364
214365@views  function  __mirkn_loss (u, p, y, pt:: TwoPointSecondOrderBVProblem ,
215-         bc!:: BC , mesh, cache:: MIRKNCache ) where  {BC}
366+         bc!:: BC , mesh, cache:: MIRKNCache , _, _ ) where  {BC}
216367    y_ =  recursive_unflatten! (y, u)
217368    soly_ =  VectorOfArray (y_)
218369    resid_co =  Φ (cache, y_, u, p)
0 commit comments