Skip to content

Commit 29d1a15

Browse files
committed
feat/refactor/fix(distributionregistry): amend model, trajectories, generator and json so that the model actually successfully runs (also improved a docstring in registry) (#11)
1 parent c0b948d commit 29d1a15

File tree

7 files changed

+66
-25
lines changed

7 files changed

+66
-25
lines changed

R/add_patient_generator.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ add_patient_generator <- function(env, trajectory, unit, patient_type, param) {
2929
name_prefix = paste0(unit, "_", patient_type),
3030
trajectory = trajectory,
3131
distribution = function() {
32-
param[["dist"]][["arrivals"]][[unit]][[patient_type]]()
32+
param[["dist"]][["arrival"]][[unit]][[patient_type]]()
3333
}
3434
)
3535
}

R/create_asu_trajectory.R

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,27 @@ create_asu_trajectory <- function(env, patient_type, param) {
3636
}) |>
3737

3838
log_(function() {
39-
dest <- get_attribute(env, "post_asu_destination")
40-
paste0("\U0001F3AF Planned ASU -> ", dest)
39+
dest_num <- get_attribute(env, "post_asu_destination")
40+
dest <- param[["map_num2val"]][as.character(dest_num)]
41+
paste0("\U0001F3AF Planned ASU -> ", dest_num, " (", dest, ")")
4142
}, level = 1L) |>
4243

4344
# Sample ASU LOS. For stroke patients, LOS distribution is based on
4445
# the planned destination after the ASU.
4546
set_attribute("asu_los", function() {
46-
dest <- get_attribute(env, "post_asu_destination")
47+
dest_num <- get_attribute(env, "post_asu_destination")
48+
dest <- param[["map_num2val"]][as.character(dest_num)]
4749
if (patient_type == "stroke") {
4850
switch(
4951
dest,
50-
esd = param[["dest"]][["los"]][["asu"]][["stroke_esd"]],
51-
rehab = param[["dest"]][["los"]][["asu"]][["stroke_no_esd"]],
52-
other = param[["dest"]][["los"]][["asu"]][["stroke_mortality"]],
52+
esd = param[["dist"]][["los"]][["asu"]][["stroke_esd"]](),
53+
rehab = param[["dist"]][["los"]][["asu"]][["stroke_noesd"]](),
54+
other = param[["dist"]][["los"]][["asu"]][["stroke_mortality"]](),
5355
stop("Stroke post-asu destination '", dest, "' invalid",
5456
call. = FALSE)
5557
)
5658
} else {
57-
param[["dest"]][["los"]][["asu"]][[patient_type]]()
59+
param[["dist"]][["los"]][["asu"]][[patient_type]]()
5860
}
5961
}) |>
6062

@@ -72,7 +74,9 @@ create_asu_trajectory <- function(env, patient_type, param) {
7274
# If that patient's destination is rehab, then start on that trajectory
7375
branch(
7476
option = function() {
75-
if (get_attribute(env, "post_asu_destination") == "rehab") 1L else 0L
77+
dest_num <- get_attribute(env, "post_asu_destination")
78+
dest <- param[["map_num2val"]][as.character(dest_num)]
79+
if (dest == "rehab") 1L else 0L
7680
},
7781
continue = FALSE, # Do not continue main trajectory after branch
7882
create_rehab_trajectory(env, patient_type, param)

R/create_rehab_trajectory.R

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,26 @@ create_rehab_trajectory <- function(env, patient_type, param) {
3636
}) |>
3737

3838
log_(function() {
39-
dest <- get_attribute(env, "post_rehab_destination")
40-
paste0("\U0001F3AF Planned rehab -> ", dest_index, " (", dest, ")")
39+
dest_num <- get_attribute(env, "post_rehab_destination")
40+
dest <- param[["map_num2val"]][as.character(dest_num)]
41+
paste0("\U0001F3AF Planned rehab -> ", dest_num, " (", dest, ")")
4142
}, level = 1L) |>
4243

4344
# Sample rehab LOS. For stroke patients, LOS distribution is based on
4445
# the planned destination after the rehab
4546
set_attribute("rehab_los", function() {
46-
dest <- get_attribute(env, "post_rehab_destination")
47+
dest_num <- get_attribute(env, "post_rehab_destination")
48+
dest <- param[["map_num2val"]][as.character(dest_num)]
4749
if (patient_type == "stroke") {
4850
switch(
4951
dest,
50-
esd = param[["los"]][["rehab"]][["stroke_esd"]](),
51-
other = param[["los"]][["rehab"]][["stroke_no_esd"]](),
52+
esd = param[["dist"]][["los"]][["rehab"]][["stroke_esd"]](),
53+
other = param[["dist"]][["los"]][["rehab"]][["stroke_noesd"]](),
5254
stop("Stroke post-rehab destination '", dest, "' invalid",
5355
call. = FALSE)
5456
)
5557
} else {
56-
param[["los"]][["rehab"]][[patient_type]]()
58+
param[["dist"]][["los"]][["rehab"]][[patient_type]]()
5759
}
5860
}) |>
5961

R/distribution_registry.R

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,15 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
7878
#' @description
7979
#' Register a distribution generator under a name.
8080
#'
81-
#' Typically, the generator should be a function that takes
82-
#' distribution-specific parameters and returns a function of `size` (the
83-
#' sample size).
81+
#' Typically, the generator should be a function that:
82+
#' 1. Accepts parameters for a distribution.
83+
#' 2. Returns another function - the *sampler* - which takes a `size`
84+
#' argument and produces that many random values from the specified
85+
#' distribution.
86+
#'
87+
#' By storing generators rather than fixed samplers, you can create as many
88+
#' different samplers as you want later, each with different parameters,
89+
#' while reusing the same generator code.
8490
#'
8591
#' @param name Distribution name (string)
8692
#' @param generator Function to create a sampler given its parameters.
@@ -146,6 +152,8 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
146152
"for a lognormal distribution.")
147153
}
148154
}
155+
# Calls the `get()` method above which finds the distribution generator
156+
# function, then do.call() populates it with dots (a list of arguments).
149157
generator <- self$get(name)
150158
do.call(generator, dots)
151159
},
@@ -162,6 +170,7 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
162170
#' @return List of parameterised samplers (named if config is named).
163171
create_batch = function(config) {
164172
if (is.list(config)) {
173+
# Calls `create()` for each distribution specified in config
165174
lapply(config, function(cfg) {
166175
do.call(self$create, c(cfg$class_name, cfg$params))
167176
})

R/model.R

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,29 @@ model <- function(run_number, param, set_seed = TRUE) {
3232
param[["verbose"]] <- any(c(param[["log_to_console"]],
3333
param[["log_to_file"]]))
3434

35+
# Convert discrete categories from character to numeric (as will store using
36+
# set.attribute(), which doesn't accept strings)
37+
# Identify discrete configs
38+
discrete_cfg <- Filter(\(x) x$class_name == "discrete", param$dist_config)
39+
# Get all the unique "values" (categories)
40+
all_vals <- unique(unlist(lapply(discrete_cfg, \(x) unlist(x$params$values))))
41+
# Build mapping
42+
param[["map_val2num"]] <- setNames(seq_along(all_vals), all_vals)
43+
param[["map_num2val"]] <- setNames(all_vals, seq_along(all_vals))
44+
# Replace discrete values with numeric codes in a copy
45+
param$dist_config_num <- lapply(param$dist_config, \(cfg) {
46+
if (cfg$class_name == "discrete") {
47+
# Flatten, map to numbers, drop names
48+
cfg$params$values <- unname(
49+
param[["map_val2num"]][unlist(cfg$params$values)]
50+
)
51+
}
52+
cfg
53+
})
54+
3555
# Set up sampling distributions
3656
registry <- simulation::DistributionRegistry$new()
37-
param[["dist"]] <- registry$create_batch(as.list(param[["dist_config"]]))
57+
param[["dist"]] <- registry$create_batch(as.list(param[["dist_config_num"]]))
3858

3959
# Restructure as dist[type][unit][patient]
4060
dist <- list()
@@ -61,7 +81,7 @@ model <- function(run_number, param, set_seed = TRUE) {
6181
.env = env, name = paste0(unit, "_bed"), capacity = Inf
6282
)
6383

64-
for (patient_type in names(param[[paste0(unit, "_arrivals")]])) {
84+
for (patient_type in names(param[["dist"]][["arrival"]][[unit]])) {
6585

6686
# Create patient trajectory
6787
traj <- if (unit == "asu") {

inst/extdata/parameters.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@
114114
"rehab_routing_tia": {
115115
"class_name": "discrete",
116116
"params": {
117-
"values": ["other"],
118-
"prob": [1.0]
117+
"values": ["esd", "other"],
118+
"prob": [0.0, 1.0]
119119
}
120120
},
121121
"rehab_routing_neuro": {

man/DistributionRegistry.Rd

Lines changed: 9 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)