Skip to content

Commit fdc2be1

Browse files
authored
Merge pull request #295 from mrc-ide/mrc-4318
Support for differentiation of parameters in DSL
2 parents 4f9f69e + 5c02f09 commit fdc2be1

File tree

4 files changed

+124
-17
lines changed

4 files changed

+124
-17
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: odin
22
Title: ODE Generation and Integration
3-
Version: 1.5.3
3+
Version: 1.5.4
44
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
55
email = "rich.fitzjohn@gmail.com"),
66
person("Thibaut", "Jombart", role = "ctb"),

R/ir_parse.R

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -479,38 +479,67 @@ ir_parse_packing_internal <- function(names, rank, len, variables,
479479
## few different places. It might be worth trying to shift more of
480480
## this classification into the initial equation parsing.
481481
ir_parse_features <- function(eqs, debug, config, source) {
482-
is_update <- vlapply(eqs, function(x) identical(x$lhs$special, "update"))
483-
is_deriv <- vlapply(eqs, function(x) identical(x$lhs$special, "deriv"))
484-
is_output <- vlapply(eqs, function(x) identical(x$lhs$special, "output"))
485-
is_dim <- vlapply(eqs, function(x) identical(x$lhs$special, "dim"))
482+
is_lhs_update <- vlapply(eqs, function(x) identical(x$lhs$special, "update"))
483+
is_lhs_deriv <- vlapply(eqs, function(x) identical(x$lhs$special, "deriv"))
484+
is_lhs_output <- vlapply(eqs, function(x) identical(x$lhs$special, "output"))
485+
is_lhs_dim <- vlapply(eqs, function(x) identical(x$lhs$special, "dim"))
486486
is_user <- vlapply(eqs, function(x) !is.null(x$user))
487487
is_delay <- vlapply(eqs, function(x) !is.null(x$delay))
488488
is_interpolate <- vlapply(eqs, function(x) !is.null(x$interpolate))
489489
is_stochastic <- vlapply(eqs, function(x) isTRUE(x$stochastic))
490490
is_data <- vlapply(eqs, function(x) !is.null(x$data))
491-
is_compare <- vlapply(eqs, function(x) identical(x$lhs$special, "compare"))
491+
is_lhs_compare <- vlapply(eqs,
492+
function(x) identical(x$lhs$special, "compare"))
493+
is_user_differentiate <- vlapply(eqs,
494+
function(x) isTRUE(x$user$differentiate))
492495

493496
## We'll support other debugging bits later, I imagine.
494497
is_debug_print <- vlapply(debug, function(x) x$type == "print")
495498

496-
if (!any(is_update | is_deriv)) {
499+
if (!any(is_lhs_update | is_lhs_deriv)) {
497500
ir_parse_error("Did not find a deriv() or an update() call",
498501
NULL, NULL)
499502
}
500503

501-
list(continuous = any(is_deriv),
502-
discrete = any(is_update),
503-
mixed = any(is_update) && any(is_deriv),
504-
has_array = any(is_dim),
505-
has_output = any(is_output),
504+
continuous <- any(is_lhs_deriv)
505+
has_compare <- any(is_lhs_compare)
506+
has_array <- any(is_lhs_dim)
507+
has_derivative <- any(is_user_differentiate)
508+
509+
## Most of these constraints go away later, might as well throw them
510+
## early though; we could put it into a preliminary check for
511+
## differentiability but in some ways thast just complicates things.
512+
if (has_derivative) {
513+
if (!has_compare) {
514+
## (this one is fundamental; this just can't be done!
515+
ir_parse_error("You need a compare expression to differentiate!",
516+
ir_parse_error_lines(eqs[is_user_differentiate]), source)
517+
}
518+
if (continuous) {
519+
ir_parse_error("Can't use differentiate with continuous time models",
520+
ir_parse_error_lines(eqs[is_user_differentiate]), source)
521+
}
522+
if (has_array) {
523+
ir_parse_error(
524+
"Can't use differentiate with models that use arrays",
525+
ir_parse_error_lines(eqs[is_user_differentiate | is_lhs_dim]), source)
526+
}
527+
}
528+
529+
list(continuous = continuous,
530+
discrete = any(is_lhs_update),
531+
mixed = any(is_lhs_update) && continuous,
532+
has_array = has_array,
533+
has_output = any(is_lhs_output),
506534
has_user = any(is_user),
507535
has_delay = any(is_delay),
508536
has_interpolate = any(is_interpolate),
509537
has_stochastic = any(is_stochastic),
510538
has_data = any(is_data),
511-
has_compare = any(is_compare),
539+
has_compare = has_compare,
512540
has_include = !is.null(config$include),
513541
has_debug = any(is_debug_print),
542+
has_derivative = has_derivative,
514543
initial_time_dependent = NULL)
515544
}
516545

@@ -1040,7 +1069,9 @@ ir_parse_expr_rhs_user <- function(rhs, line, source) {
10401069
ir_parse_error("Only first argument to user() may be unnamed", line, source)
10411070
}
10421071

1043-
m <- match.call(function(default, integer, min, max, ...) NULL, rhs, FALSE)
1072+
m <- match.call(
1073+
function(default, integer, min, max, differentiate, ...) NULL,
1074+
rhs, FALSE)
10441075
extra <- m[["..."]]
10451076
if (!is.null(extra)) {
10461077
ir_parse_error(sprintf("Unknown %s to user(): %s",
@@ -1063,12 +1094,23 @@ ir_parse_expr_rhs_user <- function(rhs, line, source) {
10631094
if (length(deps$variables) > 0L) {
10641095
ir_parse_error("user() call must not reference variables", line, source)
10651096
}
1066-
## TODO: the 'dim' part here is not actually known yet!
1097+
1098+
integer <- m$integer %||% FALSE
1099+
differentiate <- m$differentiate %||% FALSE
1100+
1101+
if (differentiate && integer) {
1102+
ir_parse_error("Can't differentiate integer parameters",
1103+
line, source)
1104+
}
1105+
1106+
## NOTE: the 'dim' part here is not actually known yet!
10671107
user <- list(default = m$default,
10681108
dim = FALSE,
1069-
integer = m$integer %||% FALSE,
1109+
integer = integer,
10701110
min = m$min,
1071-
max = m$max)
1111+
max = m$max,
1112+
differentiate = differentiate)
1113+
10721114
list(user = user)
10731115
}
10741116

inst/schema.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@
204204
"has_stochastic": { "type": "boolean" },
205205
"has_include": { "type": "boolean" },
206206
"has_debug": { "type": "boolean" },
207+
"has_derivative": { "type": "boolean" },
207208
"initial_time_dependent": { "type": "boolean" }
208209
},
209210
"required": ["discrete", "has_array", "has_output", "has_user",
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
test_that("Can parse with differentiable parameters", {
2+
ir <- odin_parse({
3+
initial(x) <- 1
4+
update(x) <- rnorm(0, 0.1)
5+
d <- data()
6+
compare(d) ~ normal(0, scale)
7+
scale <- user(differentiate = TRUE)
8+
})
9+
10+
d <- ir_deserialise(ir)
11+
expect_true(d$features$has_derivative)
12+
})
13+
14+
15+
test_that("can't differentiate integer parameters", {
16+
expect_error(odin_parse({
17+
initial(x) <- 1
18+
update(x) <- rnorm(0, 0.1)
19+
d <- data()
20+
compare(d) ~ normal(x, scale)
21+
scale <- user(differentiate = TRUE, integer = TRUE)
22+
}),
23+
"Can't differentiate integer parameters\\s+scale <-")
24+
})
25+
26+
27+
test_that("can't differentiate without compare", {
28+
expect_error(
29+
odin_parse({
30+
initial(x) <- 1
31+
update(x) <- rnorm(x, scale)
32+
scale <- user(differentiate = TRUE)
33+
}),
34+
"You need a compare expression to differentiate!\\s+scale <-")
35+
})
36+
37+
38+
test_that("can't differentiate continuous time models", {
39+
expect_error(
40+
odin_parse({
41+
initial(x) <- 1
42+
deriv(x) <- 1
43+
d <- data()
44+
compare(d) ~ normal(x, scale)
45+
scale <- user(differentiate = TRUE)
46+
}),
47+
"Can't use differentiate with continuous time models\\s+scale <-")
48+
})
49+
50+
51+
test_that("can't differentiate models with arrays", {
52+
err <- expect_error(
53+
odin_parse({
54+
initial(x[]) <- 1
55+
update(x[]) <- rnorm(x, 1)
56+
dim(x) <- 5
57+
d <- data()
58+
compare(d) ~ normal(sum(x), scale)
59+
scale <- user(differentiate = TRUE)
60+
}),
61+
"Can't use differentiate with models that use arrays")
62+
expect_match(err$message, "dim(x) <-", fixed = TRUE)
63+
expect_match(err$message, "scale <-", fixed = TRUE)
64+
})

0 commit comments

Comments
 (0)