Skip to content

Commit 3e527d5

Browse files
working towards phac workshop:
o track incidence in rk4 o expose expression-list printing utilities o function for computing the change frame o cleaner printing of scalar-valued defaults o bring back special handling of randomness and time-variation with rk4 o expose colour-scheme for box drawing o docs
1 parent 7ce237e commit 3e527d5

File tree

16 files changed

+338
-70
lines changed

16 files changed

+338
-70
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: macpan2
22
Title: Fast and Flexible Compartmental Modelling
3-
Version: 1.12.0
3+
Version: 1.13.0
44
Authors@R: c(
55
person("Steve Walker", email="swalk@mcmaster.ca", role=c("cre", "aut")),
66
person("Weiguang Guan", role="aut"),

NAMESPACE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ S3method(mp_reduce,TMBModelSpec)
7676
S3method(mp_reference,Index)
7777
S3method(mp_reference,Ledger)
7878
S3method(mp_rk4,TMBModelSpec)
79+
S3method(mp_rk4_old,TMBModelSpec)
7980
S3method(mp_simulator,TMBCalibrator)
8081
S3method(mp_simulator,TMBModelSpec)
8182
S3method(mp_simulator,TMBParameterizedModelSpec)
@@ -204,6 +205,7 @@ export(make_expr_parser)
204205
export(mp_absolute_flow)
205206
export(mp_aggregate)
206207
export(mp_cartesian)
208+
export(mp_change_frame)
207209
export(mp_default)
208210
export(mp_default_list)
209211
export(mp_dynamic_model)
@@ -247,11 +249,16 @@ export(mp_per_capita_inflow)
247249
export(mp_per_capita_outflow)
248250
export(mp_poisson)
249251
export(mp_positions)
252+
export(mp_print_after)
253+
export(mp_print_before)
254+
export(mp_print_during)
255+
export(mp_print_spec)
250256
export(mp_rbf)
251257
export(mp_reduce)
252258
export(mp_reference)
253259
export(mp_rename)
254260
export(mp_rk4)
261+
export(mp_rk4_old)
255262
export(mp_set_numbers)
256263
export(mp_setdiff)
257264
export(mp_sim_bounds)

R/flow_frame.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,19 @@ mp_state_vars = function(spec) {
165165
vapply(spec$change_model$update_state(), lhs_char, character(1L))
166166
}
167167

168+
#' Change Frame
169+
#'
170+
#' Get the changes made to each state variable at each time step.
171+
#'
172+
#' @param spec Model specification (\code{\link{mp_tmb_model_spec}}).
173+
#'
174+
#' @return Data frame with two columns: `state` and `change`. Each row
175+
#' describes one change.
176+
#'
177+
#' @export
178+
mp_change_frame = function(spec) spec$change_model$change_frame()
179+
180+
168181

169182
#' Find all Paths
170183
#'

R/formula_list_generators.R

Lines changed: 132 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,20 @@ MockChangeModel = function() {
317317

318318
##' State Updates
319319
##'
320-
##' Use these functions to update a model spec so that the state variables
321-
##' are updated according to different numerical methods.
320+
##' These functions return a modified version of a model specification, such
321+
##' that the state variables are updated each time step according to different
322+
##' numerical methods.
323+
##'
324+
##' To see the computations that update the state variables under these
325+
##' modified specifications, one may use the
326+
##' \code{\link{mp_expand}} function (see examples).
322327
##'
323328
##' The default update method for model specifications produced using
324329
##' \code{\link{mp_tmb_model_spec}} is `mp_euler`. This update method
325330
##' yields a difference-equation model where the state is updated once
326331
##' per time-step using the absolute flow rate as the difference between
327332
##' steps.
328333
##'
329-
##' These state update functions are used to modify a model specification to
330-
##' use a particular kind of state update. To see these modified models for
331-
##' a particular example one may use the \code{\link{mp_expand}} function
332-
##' (see examples).
333-
##'
334334
##' @param model Object with quantities that have been explicitly
335335
##' marked as state variables.
336336
##'
@@ -341,10 +341,18 @@ MockChangeModel = function() {
341341
##' sir |> mp_rk4() |> mp_expand()
342342
##' sir |> mp_euler_multinomial() |> mp_expand()
343343
##'
344+
##' @name state_updates
345+
NULL
346+
347+
##' @describeIn state_updates ODE solver using the Euler method, which is
348+
##' equivalent to treating the model as a set of discrete-time difference
349+
##' equations. This is the default method used by
350+
##' \code{\link{mp_tmb_model_spec}}, but this default can be changed using
351+
##' the functions described below.
344352
##' @export
345353
mp_euler = function(model) UseMethod("mp_euler")
346354

347-
##' @describeIn mp_euler ODE solver using Runge-Kutta 4. Any formulas that
355+
##' @describeIn state_updates ODE solver using Runge-Kutta 4. Any formulas that
348356
##' appear before model flows in the `during` list will only be updated
349357
##' with RK4 if they do contain functions in
350358
##' `getOption("macpan2_non_iterable_funcs")` and if they do not make any
@@ -362,17 +370,24 @@ mp_euler = function(model) UseMethod("mp_euler")
362370
##' will only be called once per time-step, and so it should never be removed
363371
##' from the list of non-iterable functions. Although in principle it could
364372
##' make sense to update state variables manually, it currently causes us to
365-
##' be confused. We therefore require that all state variables updates are set
366-
##' explicitly (e.g., with \code{\link{mp_per_capita_flow}}) if any are explicit.
373+
##' be confused. We therefore require that all state variable updates are set
374+
##' explicitly (e.g., with \code{\link{mp_per_capita_flow}}).
367375
##' @export
368376
mp_rk4 = function(model) UseMethod("mp_rk4")
369377

370-
##' @describeIn mp_euler Update state with process error given by the
378+
##' @describeIn state_updates Old version of `mp_rk4` that doesn't keep track
379+
##' of absolute flows through each time-step. As a result this version is
380+
##' more efficient but makes it more difficult to compute things like
381+
##' incidence over a time scale.
382+
##' @export
383+
mp_rk4_old = function(model) UseMethod("mp_rk4_old")
384+
385+
##' @describeIn state_updates Update state with process error given by the
371386
##' Euler-multinomial distribution.
372387
##' @export
373388
mp_euler_multinomial = function(model) UseMethod("mp_euler_multinomial")
374389

375-
##' @describeIn mp_euler Update state with hazard steps, which is equivalent
390+
##' @describeIn state_updates Update state with hazard steps, which is equivalent
376391
##' to taking the step given by the expected value of the Euler-multinomial
377392
##' distribution.
378393
##' @export
@@ -384,6 +399,10 @@ mp_euler.TMBModelSpec = function(model) model$change_update_method("euler")
384399
##' @export
385400
mp_rk4.TMBModelSpec = function(model) model$change_update_method("rk4")
386401

402+
##' @export
403+
mp_rk4_old.TMBModelSpec = function(model) model$change_update_method("rk4_old")
404+
405+
387406
##' @export
388407
mp_euler_multinomial.TMBModelSpec = function(model) model$change_update_method("euler_multinomial")
389408

@@ -446,6 +465,7 @@ get_state_update_method = function(state_update, change_model) {
446465
}
447466
cls_nm = sprintf("%sUpdateMethod", var_case_to_cls_case(state_update))
448467
if (state_update == "rk4") cls_nm = "RK4UpdateMethod"
468+
if (state_update == "rk4_old") cls_nm = "RK4OldUpdateMethod"
449469
get(cls_nm)(change_model)
450470
}
451471
get_change_model = function(before, during, after) {
@@ -490,9 +510,7 @@ EulerUpdateMethod = function(change_model, existing_global_names = character())
490510
self = UpdateMethod()
491511
self$change_model = change_model
492512

493-
## nb: euler method requires no additional names from the spec
494-
#self$existing_global_names = existing_global_names
495-
513+
## euler method requires no additional names from the spec
496514

497515
self$before = function() self$change_model$before_loop()
498516
self$during = function() {
@@ -517,7 +535,7 @@ EulerUpdateMethod = function(change_model, existing_global_names = character())
517535
}
518536

519537

520-
RK4UpdateMethod = function(change_model) {
538+
RK4OldUpdateMethod = function(change_model) {
521539
self = UpdateMethod()
522540
self$change_model = change_model
523541

@@ -580,6 +598,97 @@ RK4UpdateMethod = function(change_model) {
580598
return_object(self, "EulerUpdateMethod")
581599
}
582600

601+
RK4UpdateMethod = function(change_model) {
602+
self = UpdateMethod()
603+
self$change_model = change_model
604+
605+
self$before = function() self$change_model$before_loop()
606+
self$during = function() {
607+
before_components = self$change_model$before_flows()
608+
flow_frame = self$change_model$flow_frame()
609+
components = flow_frame_to_absolute_flows(flow_frame)
610+
before_state = self$change_model$before_state()
611+
# before = c(before_components, components, before_state)
612+
update_state = self$change_model$update_state()
613+
update_flows = self$change_model$update_flows() |> unlist(recursive = FALSE, use.names = FALSE)
614+
615+
new_update = list()
616+
new_before = list()
617+
618+
states = vapply(update_state, lhs_char, character(1L))
619+
rates = vapply(update_state, rhs_char, character(1L))
620+
flows = flow_frame$change
621+
622+
existing_names = self$change_model$all_user_aware_names()
623+
local_state_step_names = list(
624+
k1 = sprintf("k1_%s", states)
625+
, k2 = sprintf("k2_%s", states)
626+
, k3 = sprintf("k3_%s", states)
627+
, k4 = sprintf("k4_%s", states)
628+
)
629+
local_flow_step_names = list(
630+
k1 = sprintf("k1_%s", flows)
631+
, k2 = sprintf("k2_%s", flows)
632+
, k3 = sprintf("k3_%s", flows)
633+
, k4 = sprintf("k4_%s", flows)
634+
)
635+
state_step_names = map_names(existing_names, local_state_step_names)
636+
flow_step_names = map_names(existing_names, local_flow_step_names)
637+
638+
rate_formulas = sprintf("%s ~ %s", state_step_names$k1, rates) |> lapply(as.formula)
639+
640+
make_before = function(stage) {
641+
stage_flow_frame = within(flow_frame, change <- flow_step_names[[stage]])
642+
stage_components = macpan2:::flow_frame_to_absolute_flows(stage_flow_frame)
643+
if (stage == "k1") {
644+
stage_before_components = before_components
645+
} else {
646+
stage_before_components = only_iterable(before_components, states)
647+
}
648+
c(stage_before_components, stage_components)
649+
}
650+
651+
## rk4 step 1
652+
flow_replacements = sprintf("%s ~ %s", flows, flow_step_names$k1) |> lapply(as.formula)
653+
k1_new_before = make_before("k1")
654+
k1_new_update = macpan2:::update_formulas(rate_formulas, flow_replacements)
655+
656+
## rk4 step 2
657+
state_replacements = sprintf("%s ~ (%s + (%s / 2))", states, states, state_step_names$k1) |> lapply(as.formula)
658+
flow_replacements = sprintf("%s ~ %s", flows, flow_step_names$k2) |> lapply(as.formula)
659+
k2_new_before = macpan2:::update_formulas(make_before("k2"), state_replacements)
660+
k2_new_update = sprintf("%s ~ %s", state_step_names$k2, rates) |> lapply(as.formula) |> macpan2:::update_formulas(flow_replacements)
661+
662+
## rk4 step 3
663+
state_replacements = sprintf("%s ~ (%s + (%s / 2))", states, states, state_step_names$k2) |> lapply(as.formula)
664+
flow_replacements = sprintf("%s ~ %s", flows, flow_step_names$k3) |> lapply(as.formula)
665+
k3_new_before = macpan2:::update_formulas(make_before("k3"), state_replacements)
666+
k3_new_update = sprintf("%s ~ %s", state_step_names$k3, rates) |> lapply(as.formula) |> macpan2:::update_formulas(flow_replacements)
667+
668+
## rk4 step 4
669+
state_replacements = sprintf("%s ~ (%s + %s)", states, states, state_step_names$k3) |> lapply(as.formula)
670+
flow_replacements = sprintf("%s ~ %s", flows, flow_step_names$k4) |> lapply(as.formula)
671+
k4_new_before = macpan2:::update_formulas(make_before("k4"), state_replacements)
672+
k4_new_update = sprintf("%s ~ %s", state_step_names$k4, rates) |> lapply(as.formula) |> macpan2:::update_formulas(flow_replacements)
673+
674+
## final update step
675+
final_flow_update = sprintf("%s ~ (%s + 2 * %s + 2 * %s + %s)/6"
676+
, flows, flow_step_names$k1, flow_step_names$k2, flow_step_names$k3, flow_step_names$k4
677+
) |> lapply(as.formula)
678+
final_state_update = sprintf("%s ~ %s %s", states, states, rates) |> lapply(as.formula) |> setNames(states)
679+
after_components = self$change_model$after_state()
680+
c(
681+
k1_new_before, k1_new_update
682+
, k2_new_before, k2_new_update
683+
, k3_new_before, k3_new_update
684+
, k4_new_before, k4_new_update
685+
, final_flow_update, final_state_update
686+
, after_components
687+
)
688+
}
689+
self$after = function() self$change_model$after_loop()
690+
return_object(self, "EulerUpdateMethod")
691+
}
583692

584693
EulerMultinomialUpdateMethod = function(change_model) {
585694
self = Base()
@@ -691,9 +800,13 @@ HazardUpdateMethod = function(change_model) {
691800
#' a two-sided formula with the left-hand-side giving the name of the absolute
692801
#' flow rate per unit time-stepand the right-hand-side giving an expression for
693802
#' the per-capita rate of flow from `from` to `to`.
694-
#' @param abs_rate String giving the name for the absolute flow rate,
695-
#' which will be computed as `from * rate`. If a formula is passed to
696-
#' `rate` (not recommended), then this `abs_rate` argument will be ignored.
803+
#' @param abs_rate String giving the name for the absolute flow rate.
804+
#' By default, during simulations, the absolute flow rate will be computed as
805+
#' `from * rate`. This default behaviour will simulate the compartmental model
806+
#' as discrete difference equations, but this default can be changed to use
807+
#' other approaches (see \code{\link{state_updates}}).
808+
#' If a formula is passed to `rate` (not recommended for better readability),
809+
#' then this `abs_rate` argument will be ignored.
697810
#' @param rate_name String giving the name for the absolute flow rate.
698811
#'
699812
#' @examples

R/lists.R

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,24 @@ melt_matrix = function(x, zeros_are_blank = TRUE) {
8383
data.frame(row = row, col = col, value = as.vector(x))
8484
}
8585

86-
melt_default_matrix_list = function(x, zeros_are_blank = TRUE) {
86+
melt_default_matrix_list = function(x, zeros_are_blank = TRUE, simplify_as_scalars = FALSE) {
8787
if (length(x) == 0L) return(NULL)
8888
f = (x
8989
|> lapply(melt_matrix, zeros_are_blank)
9090
|> bind_rows(.id = "matrix")
9191
)
92+
if (simplify_as_scalars) {
93+
rm_rs = all(f$row == "")
94+
rm_cs = all(f$col == "")
95+
if (rm_rs) f$row = NULL
96+
if (rm_cs) f$col = NULL
97+
if (rm_rs & rm_cs) {
98+
nms = colnames(f)
99+
mat_col = nms == "matrix"
100+
if (any(mat_col)) names(f)[mat_col] = "quantity"
101+
}
102+
}
103+
92104
rownames(f) = NULL
93105
f
94106
}

R/mp_tmb_model_spec.R

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ TMBModelSpec = function(
77
, must_save = character()
88
, must_not_save = character()
99
, sim_exprs = character()
10-
, state_update = c("euler", "rk4", "euler_multinomial", "hazard")
10+
, state_update = c("euler", "rk4", "euler_multinomial", "hazard", "rk4_old")
1111
) {
1212
must_not_save = handle_saving_conflicts(must_save, must_not_save)
1313
self = Base()
@@ -104,7 +104,7 @@ TMBModelSpec = function(
104104
)
105105
}
106106
self$change_update_method = function(
107-
state_update = c("euler", "rk4", "euler_multinomial", "hazard")
107+
state_update = c("euler", "rk4", "euler_multinomial", "hazard", "rk4_old")
108108
) {
109109

110110
if (self$state_update == "no") {
@@ -296,21 +296,54 @@ must_save_time_args = function(formulas) {
296296
mp_tmb_model_spec = TMBModelSpec
297297

298298
#' @export
299-
print.TMBModelSpec = function(x, ...) {
300-
spec_printer(x, include_defaults = TRUE)
301-
}
299+
print.TMBModelSpec = function(x, ...) mp_print_spec(x)
302300

303301
spec_printer = function(x, include_defaults) {
304-
#e = ExprList(x$before, x$during, x$after)
305-
#e = x$expr_list()
306302
if (include_defaults) {
307303
cat("---------------------\n")
308304
msg("Default values:\n") |> cat()
309305
cat("---------------------\n")
310-
print(melt_default_matrix_list(x$default), row.names = FALSE)
306+
print(melt_default_matrix_list(x$default, simplify_as_scalars = TRUE), row.names = FALSE)
311307
cat("\n")
312308
}
313309
exprs = c(x$before, x$during, x$after)
314310
schedule = c(length(x$before), length(x$during), length(x$after))
315311
model_steps_printer(exprs, schedule)
316312
}
313+
314+
#' Print Model Specification
315+
#'
316+
#' @param model A model produced by \code{\link{mp_tmb_model_spec}}.
317+
#'
318+
#' @export
319+
mp_print_spec = function(model) spec_printer(model, include_defaults = TRUE)
320+
321+
#' @describeIn mp_print_spec Print just the expressions executed before the
322+
#' simulation loop.
323+
#' @export
324+
mp_print_before = function(model) {
325+
model_steps_printer(
326+
model$before
327+
, c(length(model$before), 0L, 0L)
328+
)
329+
}
330+
331+
#' @describeIn mp_print_spec Print just the expressions executed during each
332+
#' iteration of the simulation loop.
333+
#' @export
334+
mp_print_during = function(model) {
335+
model_steps_printer(
336+
model$during
337+
, c(0L, length(model$during), 0L)
338+
)
339+
}
340+
341+
#' @describeIn mp_print_spec Print just the expressions executed after the
342+
#' simulation loop.
343+
#' @export
344+
mp_print_after = function(model) {
345+
model_steps_printer(
346+
model$after
347+
, c(0L, 0L, length(model$after))
348+
)
349+
}

0 commit comments

Comments
 (0)