Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions R/Design-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,12 @@ setMethod("simulate",
## what is the next dose to be used?
## initialize with starting dose
thisDose <- object@startingDose


# initialize a dataframe to store cohort probabilities
cohort_probs_df <- data.frame(Cohort = integer(), Dose = numeric(), UD = numeric(), TD = numeric(), OD = numeric())
# Initialize cohort index outside the simulation function
cohort_index <- 1

## inside this loop we simulate the whole trial, until stopping
while (!stopit) {
## what is the probability for tox. at this dose?
Expand Down Expand Up @@ -157,15 +162,23 @@ setMethod("simulate",
options = mcmcOptions
)

## => what is the next best dose?
thisDose <- nextBest(object@nextBest,
## => what is the next best dose? And estimated probabilities?
next_best_d <- nextBest(object@nextBest,
doselimit = doselimit,
samples = thisSamples,
model = object@model,
data = thisData
)$value
)
# creating cohort prob matrix
cohort_probs_df <- rbind(cohort_probs_df, cbind(
cohort_index, next_best_d$probs
))

# Increment cohort index after processing all doses in this cohort
cohort_index <- cohort_index + 1

thisDose <- next_best_d$value

## evaluate stopping rules
stopit <- stopTrial(object@stopping,
dose = thisDose,
Expand Down Expand Up @@ -211,7 +224,11 @@ setMethod("simulate",
"message"
),
report_results = stopit_results,
additional_stats = additional_stats
additional_stats = additional_stats,
cohort_probs = {
rownames(cohort_probs_df) <- NULL
cohort_probs_df
}
)
return(thisResult)
}
Expand Down Expand Up @@ -244,6 +261,7 @@ setMethod("simulate",
stop_report = simulations_output$stop_matrix,
stop_reasons = simulations_output$stopReasons,
additional_stats = simulations_output$additional_stats,
cohort_probs = simulations_output$cohort_probs,
seed = RNGstate
)

Expand Down
120 changes: 73 additions & 47 deletions R/Model-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -856,28 +856,27 @@ LogisticKadaneBetaGamma <- function(theta, xmin, xmax, alpha, beta, shape, rate)
#' @description `r lifecycle::badge("stable")`
#'
#' [`LogisticNormalMixture`] is the class for standard logistic regression model
#' with a mixture of two bivariate normal priors on the intercept and slope parameters.
#' with a mixture of k (>1) bivariate normal priors on the intercept and slope parameters.
#'
#' @details The covariate is the natural logarithm of the dose \eqn{x} divided by
#' the reference dose \eqn{x*}, i.e.:
#' \deqn{logit[p(x)] = alpha0 + alpha1 * log(x/x*),}
#' where \eqn{p(x)} is the probability of observing a DLT for a given dose \eqn{x}.
#' The prior
#' \deqn{(alpha0, alpha1) ~ w * Normal(mean1, cov1) + (1 - w) * Normal(mean2, cov2).}
#' The weight w for the first component is assigned a beta prior `B(a, b)`.
#' \deqn{(alpha0, alpha1) ~ w1 * Normal(mean1, cov1) + w2 * Normal(mean2, cov2) + ... + wk * Normal(meank, covk)}
#' The weights (w1, w2, ..., wk) for k components is assigned a dirichlet prior `Dir(a1, a2, ..., ak)`.
#'
#' @note The weight of the two normal priors is a model parameter, hence it is a
#' @note Weights of normal priors is a model parameter, hence it is a
#' flexible mixture. This type of prior is often used with a mixture of a minimal
#' informative and an informative component, in order to make the CRM more robust
#' to data deviations from the informative component.
#'
#' @slot comp1 (`ModelParamsNormal`)\cr bivariate normal prior specification of
#' the first component.
#' @slot comp2 (`ModelParamsNormal`)\cr bivariate normal prior specification of
#' the second component.
#' @slot weightpar (`numeric`)\cr the beta parameters for the weight of the
#' first component. It must a be a named vector of length 2 with names `a` and
#' `b` and with strictly positive values.
#' @slot components (`list`)\cr the specifications of the mixture components,
#' a list with [`ModelParamsNormal`] objects for each bivariate (log) normal
#' prior.
#' @slot weightpars (`numeric`)\cr the dirichlet parameters for weights of
#' components. It is a vector of length k with with strictly positive values.
#' For two components, it is a two-parameter of beta distribution.
#' @slot ref_dose (`positive_number`)\cr the reference dose.
#'
#' @seealso [`ModelParamsNormal`], [`ModelLogNormal`],
Expand All @@ -890,15 +889,16 @@ LogisticKadaneBetaGamma <- function(theta, xmin, xmax, alpha, beta, shape, rate)
Class = "LogisticNormalMixture",
contains = "GeneralModel",
slots = c(
comp1 = "ModelParamsNormal",
comp2 = "ModelParamsNormal",
weightpar = "numeric",
components = "list", # List of (length k) ModelParamsNormal objects
weightpars = "numeric", # Dirichlet parameters (length K)
prior_weights = "numeric", # User defined prior weights
M = "numeric", # Required for prior_weights
ref_dose = "numeric"
),
prototype = prototype(
comp1 = ModelParamsNormal(mean = c(0, 1), cov = diag(2)),
comp2 = ModelParamsNormal(mean = c(-1, 1), cov = diag(2)),
weightpar = c(a = 1, b = 1),
components = list(ModelParamsNormal(mean = c(0, 1), cov = diag(2)),
ModelParamsNormal(mean = c(-1, 1), cov = diag(2))),
weightpars = c(1,1),
ref_dose = 1
),
validity = v_model_logistic_normal_mix
Expand All @@ -908,30 +908,48 @@ LogisticKadaneBetaGamma <- function(theta, xmin, xmax, alpha, beta, shape, rate)

#' @rdname LogisticNormalMixture-class
#'
#' @param comp1 (`ModelParamsNormal`)\cr bivariate normal prior specification of
#' the first component. See [`ModelParamsNormal`] for more details.
#' @param comp2 (`ModelParamsNormal`)\cr bivariate normal prior specification of
#' the second component. See [`ModelParamsNormal`] for more details.
#' @param weightpar (`numeric`)\cr the beta parameters for the weight of the
#' first component. It must a be a named vector of length 2 with names `a` and
#' `b` and with strictly positive values.
#' @param components (`ModelParamsNormal`)\cr the specifications of the mixture components,
#' a list with [`ModelParamsNormal`] objects for each bivariate (log) normal
#' prior. See [`ModelParamsNormal`] for more details.
#' @param weightpars (`numeric`)\cr the dirichlet parameters for the weights of
#' K components. It is a vector of length K with strictly positive values.
#' @param prior_weights Optional length-K vector summing to ~1; if provided with M, weightpars := 1 + M * prior_weights.
#' @param M Optional positive real number concentration used with prior_weights (total concentration = K + M).
#' @param ref_dose (`number`)\cr the reference dose \eqn{x*}
#' (strictly positive number).
#'
#' @export
#' @example examples/Model-class-LogisticNormalMixture.R
#'
LogisticNormalMixture <- function(comp1,
comp2,
weightpar,
LogisticNormalMixture <- function(components,
weightpars = NULL, # optional 1: user can pass Dirichlet alphas directly, if not chosen option 2
prior_weights = NULL, # optional 2: user-supplied (w1,...,wk), sums to 1, if not chosen option 1
M = NULL, # optional 2: concentration par positive real number , if not chosen option 1. Higher the value of M parameters are more cpncentrated. If M->inf, (weights with prior)-> (fixed weights)
ref_dose) {
k <- length(components)
# --- derive weightpars if needed ---
if (is.null(weightpars)) {
stopifnot(!is.null(prior_weights), !is.null(M))
w <- as.numeric(prior_weights)
stopifnot(length(w) == k, all(is.finite(w)), all(w >= 0))
s <- sum(w)
# tolerate tiny floating error and normalize
if (!isTRUE(all.equal(s, 1, tolerance = 1e-8))) w <- w / s

# require positive integer M
stopifnot(length(M) == 1, is.finite(M), M > 0, M == as.integer(M))

# Dirichlet alphas: 1 + M * w
weightpars <- 1 + M * w
} else {
stopifnot(length(weightpars) == k, all(is.finite(weightpars)), all(weightpars > 0))
}
assert_number(ref_dose)

.LogisticNormalMixture(
comp1 = comp1,
comp2 = comp2,
weightpar = weightpar,
ref_dose = ref_dose,
components = components,
weightpars = weightpars,
ref_dose = positive_number(ref_dose),
datamodel = function() {
# The logistic likelihood - the same as for non-mixture case.
for (i in 1:nObs) {
Expand All @@ -940,21 +958,28 @@ LogisticNormalMixture <- function(comp1,
}
},
priormodel = function() {
w ~ dbeta(weightpar[1], weightpar[2])
wc <- 1 - w
comp0 ~ dbern(wc)
comp <- comp0 + 1
# Conditional on the component index "comp", which is 1 or 2.
# comp = 1 with probability "w" and comp = 2 with probability "1 - w".
weights[1:k] ~ ddirch(weightpars[1:k])
comp ~ dcat(weights[1:k])
# Conditional on the component index "comp", which is a integer drawn from (1,2, ..., k).
# comp = 1 with probability "w1", comp = 2 with probability "w2", ..., comp = k with probability "wk".
theta ~ dmnorm(mean[1:2, comp], prec[1:2, 1:2, comp])
alpha0 <- theta[1]
alpha1 <- theta[2]
alpha1 <- exp(theta[2]) # Boehringer's model is logit(pi) = log(alpha)+beta*(d/d*) where log(alpha)=alpha0, beta=alpha1 and (log(alpha),log(beta))~BNM
},
modelspecs = function(from_prior) {
k <- length(components)
# Build mean (2 x k) and prec (2 x 2 x k) for JAGS
mean <- do.call(cbind, lapply(components, function(cmp) cmp@mean))
prec <- array(NA_real_, dim = c(2, 2, k))
for (j in seq_len(k)) {
P <- if (!is.null(components[[j]]@prec)) components[[j]]@prec else solve(components[[j]]@cov)
prec[, , j] <- P
}
ms <- list(
mean = cbind(comp1@mean, comp2@mean),
prec = array(data = c(comp1@prec, comp2@prec), dim = c(2, 2, 2)),
weightpar = weightpar
k = as.integer(k),
weightpars = as.numeric(weightpars),
mean = mean,
prec = prec
)
if (!from_prior) {
ms$ref_dose <- ref_dose
Expand All @@ -965,7 +990,7 @@ LogisticNormalMixture <- function(comp1,
list(theta = c(0, 1))
},
datanames = c("nObs", "y", "x"),
sample = c("alpha0", "alpha1", "w")
sample = c("alpha0", "alpha1", "weights", "comp")
)
}

Expand All @@ -976,15 +1001,16 @@ LogisticNormalMixture <- function(comp1,
#' @export
.DefaultLogisticNormalMixture <- function() { # nolint
LogisticNormalMixture(
comp1 = ModelParamsNormal(
components = list(ModelParamsNormal(
mean = c(-0.85, 1),
cov = matrix(c(1, -0.5, -0.5, 1), nrow = 2)
),
comp2 = ModelParamsNormal(
ModelParamsNormal(
mean = c(1, 1.5),
cov = matrix(c(1.2, -0.45, -0.45, 0.6), nrow = 2)
),
weightpar = c(a = 1, b = 1),
)
),
weightpars = c(1, 1),
ref_dose = 50
)
}
Expand Down Expand Up @@ -1131,7 +1157,7 @@ LogisticNormalFixedMixture <- function(components,
list(theta = c(0, 1))
},
datanames = c("nObs", "y", "x"),
sample = c("alpha0", "alpha1")
sample = c("alpha0", "alpha1", "comp")
)
}

Expand Down
42 changes: 38 additions & 4 deletions R/Model-validity.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,47 @@ v_model_logistic_kadane_beta_gamma <- function(object) { # nolintr
v$result()
}

#' @describeIn v_model_objects validates that `weightpar` is valid.
v_model_logistic_normal_mix <- function(object) {
#' @describeIn v_model_objects validates a K-component Logistic-Normal mixture
#' with Dirichlet weights specified by `weightpars`.
v_model_logistic_normal_mix <- function(object) { # replaces the old Beta check
v <- Validate()
# 1) components: require ModelParamsNormal OR list(mean, prec|cov)
is_mpn <- sapply(object@components, test_class, "ModelParamsNormal")
if (!all(is_mpn)) {
# accept plain lists like list(mean=..., prec=...|cov=...)
list_ok <- all(sapply(object@components, function(z) {
is.list(z) &&
is.numeric(z$mean) && length(z$mean) == 2 &&
(
(!is.null(z$prec) && test_matrix(z$prec, nrows = 2, ncols = 2, any.missing = FALSE) && h_is_positive_definite(z$prec)) ||
(!is.null(z$cov) && test_matrix(z$cov, nrows = 2, ncols = 2, any.missing = FALSE) && h_is_positive_definite(z$cov))
)
}))
v$check(list_ok,
"components must be a list of ModelParamsNormal or lists with numeric mean (length 2) and 2x2 cov/prec (PD)")
} else {
comp_valid_result <- sapply(object@components, validObject, test = TRUE)
v$check(all(sapply(comp_valid_result, isTRUE)),
paste("components must be valid ModelParamsNormal objects",
paste(unlist(comp_valid_result[!sapply(comp_valid_result, isTRUE)]), collapse = ", ")))
}

k <- length(object@components)
v$check(k >= 2, "components must contain at least two mixture components")

# 2) Dirichlet hyperparameters: numeric vector of length K (positive, finite)
v$check(
h_test_named_numeric(object@weightpar, permutation.of = c("a", "b")),
"weightpar must be a named numerical vector of length two with positive finite values and names 'a', 'b'"
test_numeric(object@weightpars, lower = .Machine$double.xmin,
finite = TRUE, any.missing = FALSE, len = k),
"weightpars must be a positive, finite numeric vector of length equal to components (Dirichlet parameters)"
)

# 3) (Optional) ref_dose positive if present in this class
if (!is.null(slotNames(object)) && "ref_dose" %in% slotNames(object)) {
v$check(test_number(object@ref_dose, lower = .Machine$double.xmin, finite = TRUE),
"ref_dose must be a positive, finite number")
}

v$result()
}

Expand Down
Loading