Skip to content

Commit c49cc4a

Browse files
committed
fix(test): fixed tests to work with new syntax from distributionregistry (#11)
1 parent b51d3de commit c49cc4a

File tree

3 files changed

+133
-72
lines changed

3 files changed

+133
-72
lines changed

tests/testthat/test-backtest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
test_that("results from a new run match those previously generated", {
77
# Run the model for 2 replications
8-
param <- create_parameters(cores = 1L, number_of_runs = 2L)
8+
param <- parameters(cores = 1L, number_of_runs = 2L)
99
results <- runner(param = param)
1010

1111
# Extract and format the results (e.g. sort, dataframe, column type)

tests/testthat/test-functionaltest.R

Lines changed: 129 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,122 +4,190 @@
44
# functionality.
55

66

7+
# -----------------------------------------------------------------------------
8+
# Helper function
9+
# -----------------------------------------------------------------------------
10+
11+
12+
#' Update one or more probabilities in a routing parameter list.
13+
#'
14+
#' @param param The full model parameters list, as returned by [parameters()].
15+
#' @param routing_name Character string naming the routing block within
16+
#' `param$dist_config` (e.g. `"asu_routing_tia"`).
17+
#' @param updates Named numeric vector or list, where names are route names
18+
#' and values are the new probabilities to set.
19+
#'
20+
#' @return The modified `params_list` with the updated probability.
21+
22+
update_routing_prob <- function(param, routing_name, updates) {
23+
if (!routing_name %in% names(param$dist_config)) {
24+
stop(sprintf("Routing '%s' not found in param$dist_config", routing_name),
25+
call. = FALSE)
26+
}
27+
28+
params_list <- param$dist_config[[routing_name]]$params
29+
30+
if (is.null(names(updates)) || any(names(updates) == "")) {
31+
stop("'updates' must be a named vector or list", call. = FALSE)
32+
}
33+
34+
for (route in names(updates)) {
35+
idx <- which(params_list$values == route)
36+
if (length(idx) != 1L) {
37+
stop(sprintf(
38+
"Expected exactly one match for route '%s', found %d",
39+
route, length(idx)
40+
), call. = FALSE)
41+
}
42+
params_list$prob[[idx]] <- updates[[route]]
43+
}
44+
45+
param$dist_config[[routing_name]]$params <- params_list
46+
param
47+
}
48+
49+
750
# -----------------------------------------------------------------------------
851
# 1. Parameter validation
952
# -----------------------------------------------------------------------------
1053

1154
test_that("model errors for invalid asu_arrivals values", {
12-
param <- create_parameters()
55+
param <- parameters()
1356
# Negative value for stroke
14-
param$asu_arrivals$stroke <- -1L
57+
param$dist_config$asu_arrival_stroke$params$mean <- -1L
1558
expect_error(
1659
model(param = param, run_number = 1L),
17-
'All values in "asu_arrivals" must be greater than 0.'
60+
'All values in "asu_arrival_stroke$params$mean" must be greater than 0.',
61+
fixed = TRUE
1862
)
1963
# Zero value for neuro
20-
param <- create_parameters()
21-
param$asu_arrivals$neuro <- 0L
64+
param <- parameters()
65+
param$dist_config$asu_arrival_neuro$params$mean <- 0L
2266
expect_error(
2367
model(param = param, run_number = 1L),
24-
'All values in "asu_arrivals" must be greater than 0.'
68+
'All values in "asu_arrival_neuro$params$mean" must be greater than 0.',
69+
fixed = TRUE
2570
)
2671
})
2772

2873

2974
test_that("model errors for invalid asu_los values", {
30-
param <- create_parameters()
3175
# Negative mean for stroke_no_esd
32-
param$asu_los$stroke_no_esd$mean <- -5L
76+
param <- parameters()
77+
param$dist_config$asu_los_stroke_noesd$params$mean <- -5L
3378
expect_error(
3479
model(param = param, run_number = 1L),
35-
'All values in "asu_los" must be greater than 0.'
80+
'All values in "asu_los_stroke_noesd$params$mean" must be greater than 0.',
81+
fixed = TRUE
3682
)
3783
# Zero sd for tia
38-
param <- create_parameters()
39-
param$asu_los$tia$sd <- 0L
84+
param <- parameters()
85+
param$dist_config$asu_los_tia$params$sd <- 0L
4086
expect_error(
4187
model(param = param, run_number = 1L),
42-
'All values in "asu_los" must be greater than 0.'
88+
'All values in "asu_los_tia$params$sd" must be greater than 0.',
89+
fixed = TRUE
4390
)
4491
})
4592

4693

4794
test_that("model errors for invalid asu_routing probabilities", {
48-
param <- create_parameters()
95+
param <- parameters()
4996
# Non-numeric value
50-
param$asu_routing$stroke$rehab <- "a"
97+
param <- update_routing_prob(param, "asu_routing_stroke", c("rehab" = "a"))
5198
expect_error(
5299
model(param = param, run_number = 1L),
53-
'Routing vector "asu_routing$stroke" must be numeric.',
100+
'Routing vector "asu_routing_stroke$params$prob" must be numeric.',
54101
fixed = TRUE
55102
)
56103
# Probability out of bounds
57-
param <- create_parameters()
58-
param$asu_routing$stroke$rehab <- -0.1
104+
param <- parameters()
105+
param <- update_routing_prob(param, "asu_routing_stroke", c("rehab" = -0.1))
59106
expect_error(
60107
model(param = param, run_number = 1L),
61-
'All values in routing vector "asu_routing$stroke" must be between 0 and 1.', # nolint: line_length_linter
108+
'All values in routing vector "asu_routing_stroke$params$prob" must be between 0 and 1.', # nolint: line_length_linter
62109
fixed = TRUE
63110
)
64111
# Probabilities do not sum to 1
65-
param <- create_parameters()
66-
param$asu_routing$stroke$rehab <- 0.5
67-
param$asu_routing$stroke$esd <- 0.5
68-
param$asu_routing$stroke$other <- 0.5
112+
param <- parameters()
113+
param <- update_routing_prob(param, "asu_routing_stroke",
114+
c("rehab" = 0.5, "esd" = 0.5, "other" = 0.5))
69115
expect_error(
70116
model(param = param, run_number = 1L),
71-
'Values in routing vector "asu_routing$stroke" must sum to 1 (+-0.01).',
117+
'Values in routing vector "asu_routing_stroke$params$prob" must sum to 1 (+-0.01).', # nolint: line_length_linter
72118
fixed = TRUE
73119
)
74120
})
75121

76122

77123
test_that("model errors for invalid rehab_routing probabilities", {
78-
param <- create_parameters()
79-
param$rehab_routing$other$esd <- 1.5
124+
# Probabilities should be within 0 and 1
125+
param <- parameters()
126+
param <- update_routing_prob(param, "rehab_routing_other", c("esd" = 1.5))
80127
expect_error(
81128
model(param = param, run_number = 1L),
82-
'All values in routing vector "rehab_routing$other" must be between 0 and 1.', # nolint: line_length_linter
129+
'All values in routing vector "rehab_routing_other$params$prob" must be between 0 and 1.', # nolint: line_length_linter
83130
fixed = TRUE
84131
)
85-
# Probabilities do not sum to 1
86-
param <- create_parameters()
87-
param$rehab_routing$stroke$esd <- 0.8
88-
param$rehab_routing$stroke$other <- 0.3
89-
expect_error(
90-
model(param = param, run_number = 1L),
91-
'Values in routing vector "rehab_routing$stroke" must sum to 1 (+-0.01).',
92-
fixed = TRUE
93-
)
94-
})
95132

96-
97-
test_that("model errors for missing keys in asu_los", {
98-
param <- create_parameters()
99-
param$asu_los$other <- NULL # Remove required key
133+
# Probabilities should sum to 1
134+
param <- parameters()
135+
param <- update_routing_prob(param, "rehab_routing_stroke",
136+
c("esd"=0.8, "other"=0.3))
100137
expect_error(
101138
model(param = param, run_number = 1L),
102-
"Missing keys: other."
139+
'Values in routing vector "rehab_routing_stroke$params$prob" must sum to 1 (+-0.01)', # nolint: line_length_linter
140+
fixed = TRUE
103141
)
104142
})
105143

106144

107-
test_that("model errors for extra keys in asu_arrivals", {
108-
param <- create_parameters()
109-
param$asu_arrivals$extra <- 5L # Add unexpected key
110-
expect_error(
111-
model(param = param, run_number = 1L),
112-
"Extra keys: extra."
145+
patrick::with_parameters_test_that(
146+
"model errors for invalid/missing/extra keys in parameters",
147+
{
148+
param <- parameters()
149+
param <- mod(param)
150+
expect_error(model(run_number = 0L, param = param), msg, fixed = TRUE)
151+
},
152+
patrick::cases(
153+
missing_number_of_runs = list(
154+
mod = function(p) { p$number_of_runs <- NULL; p },
155+
msg = "Problem in param. Missing: number_of_runs. Extra: ."
156+
),
157+
# Missing key in param$dist_config
158+
missing_rehab_arrival_neuro = list(
159+
mod = function(p) { p$dist_config$rehab_arrival_neuro <- NULL; p },
160+
msg = "Problem in param$dist_config. Missing: rehab_arrival_neuro. Extra: ." # nolint: line_length_linter
161+
),
162+
# Missing specific dist_config key
163+
missing_rehab_los_tia = list(
164+
mod = function(p) { p$dist_config$rehab_los_tia$params <- NULL; p },
165+
msg = "Missing required parameter(s) in param$dist_configrehab_los_tia: params. Allowed: class_name, params" # nolint: line_length_linter
166+
),
167+
# Extra key in top-level param
168+
extra_top_level = list(
169+
mod = function(p) { p$extra_key <- 5L; p },
170+
msg = "Problem in param. Missing: . Extra: extra_key."
171+
),
172+
# Extra key in param$dist_config
173+
extra_in_dist_config = list(
174+
mod = function(p) { p$dist_config$extra_key <- 5L; p },
175+
msg = "Problem in param$dist_config. Missing: . Extra: extra_key."
176+
),
177+
# Extra key in nested dist_config entry
178+
extra_in_asu_arrival_stroke = list(
179+
mod = function(p) { p$dist_config$asu_arrival_stroke$extra_key <- 5L; p },
180+
msg = "Unrecognised parameter(s) in param$dist_configasu_arrival_stroke: extra_key. Allowed: class_name, params" # nolint: line_length_linter
181+
)
113182
)
114-
})
115-
183+
)
116184

117185
# -----------------------------------------------------------------------------
118186
# 2. Run results
119187
# -----------------------------------------------------------------------------
120188

121189
test_that("values are non-negative and not NA", {
122-
param <- create_parameters(
190+
param <- parameters(
123191
warm_up_period = 20L, data_collection_period = 20L,
124192
cores = 1L, number_of_runs = 1L
125193
)
@@ -145,32 +213,25 @@ patrick::with_parameters_test_that(
145213
"adjusting parameters decreases arrivals",
146214
{
147215
# Set some defaults
148-
default_param <- create_parameters(
216+
default_param <- parameters(
149217
warm_up_period = 100L, data_collection_period = 200L,
150218
cores = 1L, number_of_runs = 1L
151219
)
152220

153221
# Set up parameter sets
154222
init_param <- default_param
155223
adj_param <- default_param
156-
if (is.null(metric)) {
157-
init_param[[group]][[patient]] <- init_value
158-
adj_param[[group]][[patient]] <- adj_value
159-
} else {
160-
init_param[[group]][[patient]][[metric]] <- init_value
161-
adj_param[[group]][[patient]][[metric]] <- adj_value
162-
}
224+
init_param$dist_config[[group]]$params$mean <- init_value
225+
adj_param$dist_config[[group]]$params$mean <- adj_value
163226

164227
# Run model and compare number of arrivals
165228
init_arrivals <- nrow(runner(param = init_param)[["arrivals"]])
166229
adj_arrivals <- nrow(runner(param = adj_param)[["arrivals"]])
167230
expect_gt(init_arrivals, adj_arrivals)
168231
},
169232
patrick::cases(
170-
list(group = "asu_arrivals", patient = "stroke", metric = NULL,
171-
init_value = 2L, adj_value = 6L),
172-
list(group = "rehab_los", patient = "stroke_no_esd", metric = "mean",
173-
init_value = 30L, adj_value = 10L)
233+
list(group = "asu_arrival_stroke", init_value = 2L, adj_value = 6L),
234+
list(group = "rehab_los_stroke_noesd", init_value = 30L, adj_value = 10L)
174235
)
175236
)
176237

@@ -181,7 +242,7 @@ patrick::with_parameters_test_that(
181242

182243
test_that("the same seed returns the same result", {
183244

184-
param <- create_parameters(
245+
param <- parameters(
185246
warm_up_period = 20L, data_collection_period = 20L,
186247
cores = 1L, number_of_runs = 3L
187248
)
@@ -204,7 +265,7 @@ test_that("the same seed returns the same result", {
204265

205266
test_that("model and runner produce same results if override future.seed", {
206267

207-
param <- create_parameters(
268+
param <- parameters(
208269
warm_up_period = 20L, data_collection_period = 20L,
209270
cores = 1L, number_of_runs = 3L
210271
)
@@ -233,7 +294,7 @@ test_that("model and runner produce same results if override future.seed", {
233294
test_that("results are as expected if model runs with only a warm-up", {
234295

235296
# Run with only warm-up and no data collection period
236-
param <- create_parameters(
297+
param <- parameters(
237298
warm_up_period = 100L, data_collection_period = 0L,
238299
cores = 1L, number_of_runs = 1L
239300
)
@@ -252,7 +313,7 @@ test_that("results are as expected if model runs with only a warm-up", {
252313

253314
test_that("running with warm-up leads to different results than without", {
254315
# Run without warm-up, expect first audit to have time and occupancy of 0
255-
param <- create_parameters(
316+
param <- parameters(
256317
warm_up_period = 0L, data_collection_period = 20L,
257318
cores = 1L, number_of_runs = 1L
258319
)
@@ -264,7 +325,7 @@ test_that("running with warm-up leads to different results than without", {
264325
expect_true(all(first_audit[["occupancy"]] == 0L))
265326

266327
# Run with warm-up, expect first audit to have time and occupancy > 0
267-
param <- create_parameters(
328+
param <- parameters(
268329
warm_up_period = 50L, data_collection_period = 20L,
269330
cores = 1L, number_of_runs = 1L
270331
)
@@ -284,7 +345,7 @@ test_that("running with warm-up leads to different results than without", {
284345
test_that("log to console and file work correctly", {
285346
# Set parameters and create temporary file for log
286347
log_file <- tempfile(fileext = ".log")
287-
param <- create_parameters(
348+
param <- parameters(
288349
warm_up_period = 0L,
289350
data_collection_period = 20L,
290351
log_to_console = TRUE,

tests/testthat/test-unittest.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ test_that("create returns a sampler that samples correctly", {
113113
sampler <- reg$create("normal", mean = 10L, sd = 2L)
114114
samples <- sampler(size = 5L)
115115
expect_length(samples, 5L)
116-
expect_type(samples, "numeric")
116+
expect_type(samples, "double")
117117
})
118118

119119
test_that("register adds and retrieves custom distribution", {
@@ -133,6 +133,6 @@ test_that("create_batch creates multiple samplers", {
133133
expect_length(batch, 2L)
134134
expect_type(batch[[1L]], "closure")
135135
expect_type(batch[[2L]], "closure")
136-
expect_type(batch[[1L]](2L), "numeric")
137-
expect_true(batch[[2L]](2L), "numeric")
136+
expect_type(batch[[1L]](size = 2L), "double")
137+
expect_type(batch[[2L]](size = 2L), "integer")
138138
})

0 commit comments

Comments
 (0)