@@ -317,20 +317,20 @@ MockChangeModel = function() {
317317
318318# #' State Updates
319319# #'
320- # #' Use these functions to update a model spec so that the state variables
321- # #' are updated according to different numerical methods.
320+ # #' These functions return a modified version of a model specification, such
321+ # #' that the state variables are updated each time step according to different
322+ # #' numerical methods.
323+ # #'
324+ # #' To see the computations that update the state variables under these
325+ # #' modified specifications, one may use the
326+ # #' \code{\link{mp_expand}} function (see examples).
322327# #'
323328# #' The default update method for model specifications produced using
324329# #' \code{\link{mp_tmb_model_spec}} is `mp_euler`. This update method
325330# #' yields a difference-equation model where the state is updated once
326331# #' per time-step using the absolute flow rate as the difference between
327332# #' steps.
328333# #'
329- # #' These state update functions are used to modify a model specification to
330- # #' use a particular kind of state update. To see these modified models for
331- # #' a particular example one may use the \code{\link{mp_expand}} function
332- # #' (see examples).
333- # #'
334334# #' @param model Object with quantities that have been explicitly
335335# #' marked as state variables.
336336# #'
@@ -341,10 +341,18 @@ MockChangeModel = function() {
341341# #' sir |> mp_rk4() |> mp_expand()
342342# #' sir |> mp_euler_multinomial() |> mp_expand()
343343# #'
344+ # #' @name state_updates
345+ NULL
346+
347+ # #' @describeIn state_updates ODE solver using the Euler method, which is
348+ # #' equivalent to treating the model as a set of discrete-time difference
349+ # #' equations. This is the default method used by
350+ # #' \code{\link{mp_tmb_model_spec}}, but this default can be changed using
351+ # #' the functions described below.
344352# #' @export
345353mp_euler = function (model ) UseMethod(" mp_euler" )
346354
347- # #' @describeIn mp_euler ODE solver using Runge-Kutta 4. Any formulas that
355+ # #' @describeIn state_updates ODE solver using Runge-Kutta 4. Any formulas that
348356# #' appear before model flows in the `during` list will only be updated
349357# #' with RK4 if they do contain functions in
350358# #' `getOption("macpan2_non_iterable_funcs")` and if they do not make any
@@ -362,17 +370,24 @@ mp_euler = function(model) UseMethod("mp_euler")
362370# #' will only be called once per time-step, and so it should never be removed
363371# #' from the list of non-iterable functions. Although in principle it could
364372# #' make sense to update state variables manually, it currently causes us to
365- # #' be confused. We therefore require that all state variables updates are set
366- # #' explicitly (e.g., with \code{\link{mp_per_capita_flow}}) if any are explicit .
373+ # #' be confused. We therefore require that all state variable updates are set
374+ # #' explicitly (e.g., with \code{\link{mp_per_capita_flow}}).
367375# #' @export
368376mp_rk4 = function (model ) UseMethod(" mp_rk4" )
369377
370- # #' @describeIn mp_euler Update state with process error given by the
378+ # #' @describeIn state_updates Old version of `mp_rk4` that doesn't keep track
379+ # #' of absolute flows through each time-step. As a result this version is
380+ # #' more efficient but makes it more difficult to compute things like
381+ # #' incidence over a time scale.
382+ # #' @export
383+ mp_rk4_old = function (model ) UseMethod(" mp_rk4_old" )
384+
385+ # #' @describeIn state_updates Update state with process error given by the
371386# #' Euler-multinomial distribution.
372387# #' @export
373388mp_euler_multinomial = function (model ) UseMethod(" mp_euler_multinomial" )
374389
375- # #' @describeIn mp_euler Update state with hazard steps, which is equivalent
390+ # #' @describeIn state_updates Update state with hazard steps, which is equivalent
376391# #' to taking the step given by the expected value of the Euler-multinomial
377392# #' distribution.
378393# #' @export
@@ -384,6 +399,10 @@ mp_euler.TMBModelSpec = function(model) model$change_update_method("euler")
384399# #' @export
385400mp_rk4.TMBModelSpec = function (model ) model $ change_update_method(" rk4" )
386401
402+ # #' @export
403+ mp_rk4_old.TMBModelSpec = function (model ) model $ change_update_method(" rk4_old" )
404+
405+
387406# #' @export
388407mp_euler_multinomial.TMBModelSpec = function (model ) model $ change_update_method(" euler_multinomial" )
389408
@@ -446,6 +465,7 @@ get_state_update_method = function(state_update, change_model) {
446465 }
447466 cls_nm = sprintf(" %sUpdateMethod" , var_case_to_cls_case(state_update ))
448467 if (state_update == " rk4" ) cls_nm = " RK4UpdateMethod"
468+ if (state_update == " rk4_old" ) cls_nm = " RK4OldUpdateMethod"
449469 get(cls_nm )(change_model )
450470}
451471get_change_model = function (before , during , after ) {
@@ -490,9 +510,7 @@ EulerUpdateMethod = function(change_model, existing_global_names = character())
490510 self = UpdateMethod()
491511 self $ change_model = change_model
492512
493- # # nb: euler method requires no additional names from the spec
494- # self$existing_global_names = existing_global_names
495-
513+ # # euler method requires no additional names from the spec
496514
497515 self $ before = function () self $ change_model $ before_loop()
498516 self $ during = function () {
@@ -517,7 +535,7 @@ EulerUpdateMethod = function(change_model, existing_global_names = character())
517535}
518536
519537
520- RK4UpdateMethod = function (change_model ) {
538+ RK4OldUpdateMethod = function (change_model ) {
521539 self = UpdateMethod()
522540 self $ change_model = change_model
523541
@@ -580,6 +598,97 @@ RK4UpdateMethod = function(change_model) {
580598 return_object(self , " EulerUpdateMethod" )
581599}
582600
601+ RK4UpdateMethod = function (change_model ) {
602+ self = UpdateMethod()
603+ self $ change_model = change_model
604+
605+ self $ before = function () self $ change_model $ before_loop()
606+ self $ during = function () {
607+ before_components = self $ change_model $ before_flows()
608+ flow_frame = self $ change_model $ flow_frame()
609+ components = flow_frame_to_absolute_flows(flow_frame )
610+ before_state = self $ change_model $ before_state()
611+ # before = c(before_components, components, before_state)
612+ update_state = self $ change_model $ update_state()
613+ update_flows = self $ change_model $ update_flows() | > unlist(recursive = FALSE , use.names = FALSE )
614+
615+ new_update = list ()
616+ new_before = list ()
617+
618+ states = vapply(update_state , lhs_char , character (1L ))
619+ rates = vapply(update_state , rhs_char , character (1L ))
620+ flows = flow_frame $ change
621+
622+ existing_names = self $ change_model $ all_user_aware_names()
623+ local_state_step_names = list (
624+ k1 = sprintf(" k1_%s" , states )
625+ , k2 = sprintf(" k2_%s" , states )
626+ , k3 = sprintf(" k3_%s" , states )
627+ , k4 = sprintf(" k4_%s" , states )
628+ )
629+ local_flow_step_names = list (
630+ k1 = sprintf(" k1_%s" , flows )
631+ , k2 = sprintf(" k2_%s" , flows )
632+ , k3 = sprintf(" k3_%s" , flows )
633+ , k4 = sprintf(" k4_%s" , flows )
634+ )
635+ state_step_names = map_names(existing_names , local_state_step_names )
636+ flow_step_names = map_names(existing_names , local_flow_step_names )
637+
638+ rate_formulas = sprintf(" %s ~ %s" , state_step_names $ k1 , rates ) | > lapply(as.formula )
639+
640+ make_before = function (stage ) {
641+ stage_flow_frame = within(flow_frame , change <- flow_step_names [[stage ]])
642+ stage_components = macpan2 ::: flow_frame_to_absolute_flows(stage_flow_frame )
643+ if (stage == " k1" ) {
644+ stage_before_components = before_components
645+ } else {
646+ stage_before_components = only_iterable(before_components , states )
647+ }
648+ c(stage_before_components , stage_components )
649+ }
650+
651+ # # rk4 step 1
652+ flow_replacements = sprintf(" %s ~ %s" , flows , flow_step_names $ k1 ) | > lapply(as.formula )
653+ k1_new_before = make_before(" k1" )
654+ k1_new_update = macpan2 ::: update_formulas(rate_formulas , flow_replacements )
655+
656+ # # rk4 step 2
657+ state_replacements = sprintf(" %s ~ (%s + (%s / 2))" , states , states , state_step_names $ k1 ) | > lapply(as.formula )
658+ flow_replacements = sprintf(" %s ~ %s" , flows , flow_step_names $ k2 ) | > lapply(as.formula )
659+ k2_new_before = macpan2 ::: update_formulas(make_before(" k2" ), state_replacements )
660+ k2_new_update = sprintf(" %s ~ %s" , state_step_names $ k2 , rates ) | > lapply(as.formula ) | > macpan2 ::: update_formulas(flow_replacements )
661+
662+ # # rk4 step 3
663+ state_replacements = sprintf(" %s ~ (%s + (%s / 2))" , states , states , state_step_names $ k2 ) | > lapply(as.formula )
664+ flow_replacements = sprintf(" %s ~ %s" , flows , flow_step_names $ k3 ) | > lapply(as.formula )
665+ k3_new_before = macpan2 ::: update_formulas(make_before(" k3" ), state_replacements )
666+ k3_new_update = sprintf(" %s ~ %s" , state_step_names $ k3 , rates ) | > lapply(as.formula ) | > macpan2 ::: update_formulas(flow_replacements )
667+
668+ # # rk4 step 4
669+ state_replacements = sprintf(" %s ~ (%s + %s)" , states , states , state_step_names $ k3 ) | > lapply(as.formula )
670+ flow_replacements = sprintf(" %s ~ %s" , flows , flow_step_names $ k4 ) | > lapply(as.formula )
671+ k4_new_before = macpan2 ::: update_formulas(make_before(" k4" ), state_replacements )
672+ k4_new_update = sprintf(" %s ~ %s" , state_step_names $ k4 , rates ) | > lapply(as.formula ) | > macpan2 ::: update_formulas(flow_replacements )
673+
674+ # # final update step
675+ final_flow_update = sprintf(" %s ~ (%s + 2 * %s + 2 * %s + %s)/6"
676+ , flows , flow_step_names $ k1 , flow_step_names $ k2 , flow_step_names $ k3 , flow_step_names $ k4
677+ ) | > lapply(as.formula )
678+ final_state_update = sprintf(" %s ~ %s %s" , states , states , rates ) | > lapply(as.formula ) | > setNames(states )
679+ after_components = self $ change_model $ after_state()
680+ c(
681+ k1_new_before , k1_new_update
682+ , k2_new_before , k2_new_update
683+ , k3_new_before , k3_new_update
684+ , k4_new_before , k4_new_update
685+ , final_flow_update , final_state_update
686+ , after_components
687+ )
688+ }
689+ self $ after = function () self $ change_model $ after_loop()
690+ return_object(self , " EulerUpdateMethod" )
691+ }
583692
584693EulerMultinomialUpdateMethod = function (change_model ) {
585694 self = Base()
@@ -691,9 +800,13 @@ HazardUpdateMethod = function(change_model) {
691800# ' a two-sided formula with the left-hand-side giving the name of the absolute
692801# ' flow rate per unit time-stepand the right-hand-side giving an expression for
693802# ' the per-capita rate of flow from `from` to `to`.
694- # ' @param abs_rate String giving the name for the absolute flow rate,
695- # ' which will be computed as `from * rate`. If a formula is passed to
696- # ' `rate` (not recommended), then this `abs_rate` argument will be ignored.
803+ # ' @param abs_rate String giving the name for the absolute flow rate.
804+ # ' By default, during simulations, the absolute flow rate will be computed as
805+ # ' `from * rate`. This default behaviour will simulate the compartmental model
806+ # ' as discrete difference equations, but this default can be changed to use
807+ # ' other approaches (see \code{\link{state_updates}}).
808+ # ' If a formula is passed to `rate` (not recommended for better readability),
809+ # ' then this `abs_rate` argument will be ignored.
697810# ' @param rate_name String giving the name for the absolute flow rate.
698811# '
699812# ' @examples
0 commit comments