Commit 2309b2c

estimate_contrasts() for estimate_relation() etc (#372)
* `estimate_contrasts()` for `estimate_relation()` etc
* ...
* adjustments
* fix, styler
* version
* ...
* ...
* fix
* Update estimate_contrast_methods.R
* Update estimate_contrast_methods.R
* Update estimate_contrast_methods.R
* Update estimate_contrast_methods.R
* add tests
* add test
* add test
* tests
* Update estimate_contrast_methods.R
* use validate_arg
* fix test
* fix
* no pipe
* fix
* fix?
* tests only work interactively
* fix
1 parent be0f8d5 commit 2309b2c

11 files changed: +651 / -64 lines changed
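For orientation, a minimal usage sketch of what this commit enables. The data, model, and argument values below are hypothetical and only inferred from the signature of the new `estimate_contrasts.estimate_predicted()` method shown further down, not taken from the package documentation:

library(modelbased)

# hypothetical example data and model (not part of the commit itself)
model <- lm(Sepal.Width ~ Species, data = iris)

# estimate_relation() returns an object inheriting from "estimate_predicted"
pred <- estimate_relation(model, by = "Species")

# new in this commit: contrasts computed directly from those predictions
estimate_contrasts(pred, contrast = "Species", comparison = "pairwise")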

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 Type: Package
 Package: modelbased
 Title: Estimation of Model-Based Predictions, Contrasts and Means
-Version: 0.8.9.106
+Version: 0.8.9.107
 Authors@R:
     c(person(given = "Dominique",
              family = "Makowski",

NAMESPACE

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 S3method(describe_nonlinear,data.frame)
 S3method(describe_nonlinear,estimate_predicted)
 S3method(describe_nonlinear,numeric)
+S3method(estimate_contrasts,default)
+S3method(estimate_contrasts,estimate_predicted)
 S3method(format,estimate_contrasts)
 S3method(format,estimate_grouplevel)
 S3method(format,estimate_means)
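The two new registrations mean `estimate_contrasts()` now dispatches on the class of its first argument: outputs of `estimate_relation()` and friends (class `estimate_predicted`) go to the new method, while other objects fall back to `estimate_contrasts.default()`. A small sketch for checking this interactively, reusing the hypothetical `pred` object from the sketch above:

# the generic should now have (at least) these two methods registered
methods("estimate_contrasts")

# estimate_relation() results carry the "estimate_predicted" class,
# so they are routed to the new method rather than the default
inherits(pred, "estimate_predicted")  # TRUE for an estimate_relation() result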

R/estimate_contrast_methods.R

Lines changed: 350 additions & 0 deletions
@@ -0,0 +1,350 @@ (new file, all lines added)
#' @export
estimate_contrasts.estimate_predicted <- function(model,
                                                  contrast = NULL,
                                                  by = NULL,
                                                  predict = "response",
                                                  ci = 0.95,
                                                  p_adjust = "none",
                                                  comparison = "pairwise",
                                                  verbose = TRUE,
                                                  ...) {
  # sanity check
  if (inherits(comparison, "formula")) {
    comparison <- all.vars(comparison)[1]
  }
  comparison <- insight::validate_argument(comparison, c("pairwise", "interaction"))

  # sanity check
  if (is.null(contrast)) {
    insight::format_error("Argument `contrast` must be specified and cannot be `NULL`.")
  }

  # the "model" object is an object of class "estimate_predicted", we want
  # to copy that into a separate object, for clearer names
  predictions <- object <- model
  model <- attributes(object)$model
  datagrid <- attributes(object)$datagrid

  # vcov matrix, for adjusting se
  vcov_matrix <- .safe(stats::vcov(model, verbose = FALSE, ...))

  minfo <- insight::model_info(model)

  # model df
  dof <- insight::get_df(model, type = "wald", verbose = FALSE)
  crit_factor <- (1 + ci) / 2

  ## TODO: For Bayesian models, we always use the returned standard errors
  # need to check whether scale is always correct

  # for non-Gaussian models, we need to adjust the standard errors
  if (!minfo$is_linear && !minfo$is_bayesian) {
    se_from_predictions <- tryCatch(
      {
        # arguments for predict(), to get SE on response scale for non-Gaussian models
        my_args <- list(
          model,
          newdata = datagrid,
          type = predict,
          se.fit = TRUE
        )
        # for mixed models, need to set re.form to NULL or NA
        if (insight::is_mixed_model(model)) {
          my_args$re.form <- NULL
        }
        do.call(stats::predict, my_args)
      },
      error = function(e) {
        e
      }
    )
    # check if everything worked as expected
    if (inherits(se_from_predictions, "error")) {
      insight::format_error(
        "This model (family) is probably not supported. The error that occurred was:",
        se_from_predictions$message
      )
    }
    # check if we have standard errors
    if (is.null(se_from_predictions$se.fit)) {
      insight::format_error("Could not extract standard errors from predictions.")
    }
    preds_with_se <- merge(
      predictions,
      cbind(datagrid, se_prob = se_from_predictions$se.fit),
      sort = FALSE,
      all = TRUE
    )
    predictions$SE <- preds_with_se$se_prob
  } else {
    # for linear models, we don't need adjustment of standard errors
    vcov_matrix <- NULL
  }

  # compute contrasts or comparisons
  out <- switch(comparison,
    pairwise = .compute_comparisons(predictions, dof, vcov_matrix, datagrid, contrast, by, crit_factor),
    interaction = .compute_interactions(predictions, dof, vcov_matrix, datagrid, contrast, by, crit_factor)
  )

  # restore attributes, for formatting
  info <- attributes(object)
  attributes(out) <- utils::modifyList(attributes(out), info[.info_elements()])

  # overwrite some of the attributes
  attr(out, "contrast") <- contrast
  attr(out, "focal_terms") <- c(contrast, by)
  attr(out, "by") <- by

  # format output
  out <- format.marginaleffects_contrasts(out, model, p_adjust, comparison, ...)

  # p-value adjustment?
  if (!is.null(p_adjust)) {
    out <- .p_adjust(model, out, p_adjust, verbose, ...)
  }

  # Table formatting
  attr(out, "table_title") <- c("Model-based Contrasts Analysis", "blue")
  attr(out, "table_footer") <- .table_footer(
    out,
    by = contrast,
    type = "contrasts",
    model = model,
    info = info
  )

  # Add attributes
  attr(out, "model") <- model
  attr(out, "response") <- insight::find_response(model)
  attr(out, "ci") <- ci
  attr(out, "p_adjust") <- p_adjust

  # add attributes from workhorse function
  attributes(out) <- utils::modifyList(attributes(out), info[.info_elements()])

  # Output
  class(out) <- unique(c("estimate_contrasts", "see_estimate_contrasts", class(out)))
  out
}
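The non-Gaussian branch above relies on `stats::predict()` returning standard errors on the requested scale. A self-contained sketch of that idea for a plain logistic regression, with made-up data (not part of the package or the commit):

# toy logistic regression on hypothetical data
set.seed(123)
d <- data.frame(
  y = rbinom(100, 1, 0.4),
  group = factor(sample(c("a", "b", "c"), 100, replace = TRUE))
)
m <- glm(y ~ group, data = d, family = binomial())

# standard errors on the response (probability) scale, analogous to the
# `my_args` call that the method builds and passes to do.call(stats::predict, ...)
grid <- data.frame(group = factor(c("a", "b", "c")))
p <- predict(m, newdata = grid, type = "response", se.fit = TRUE)
p$fit     # predicted probabilities per group
p$se.fit  # standard errors that would later be combined into contrast SEs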
# pairwise comparisons ----------------------------------------------------
.compute_comparisons <- function(predictions, dof, vcov_matrix, datagrid, contrast, by, crit_factor) {
  # we need the focal terms and all unique values from the datagrid
  focal_terms <- c(contrast, by)
  at_list <- lapply(datagrid[focal_terms], unique)

  # pairwise comparisons are a bit more complicated, as we need to create
  # pairwise combinations of the levels of the focal terms.

  # since we split at "." later, we need to replace "." in all levels
  # with a unique character combination
  at_list <- lapply(at_list, function(i) {
    gsub(".", "#_#", as.character(i), fixed = TRUE)
  })
  # create pairwise combinations
  level_pairs <- interaction(expand.grid(at_list))
  # using the matrix and then removing the lower triangle, we get all
  # pairwise combinations, except the ones that are the same
  M <- matrix(
    1,
    nrow = length(level_pairs),
    ncol = length(level_pairs),
    dimnames = list(level_pairs, level_pairs)
  )
  M[!upper.tri(M)] <- NA
  # table() works fine to create variables for these pairwise combinations
  pairs_data <- stats::na.omit(as.data.frame(as.table(M)))
  pairs_data$Freq <- NULL
  pairs_data <- lapply(pairs_data, as.character)
  # the levels are combined by ".", we need to split them and then create
  # a list of data frames, where each data frame contains the levels of
  # the focal terms as variables
  pairs_data <- lapply(pairs_data, function(i) {
    # split at ".", which is the separator char for levels
    pair <- strsplit(i, ".", fixed = TRUE)
    # since we replaced "." with "#_#" in original levels,
    # we need to replace it back here
    pair <- lapply(pair, gsub, pattern = "#_#", replacement = ".", fixed = TRUE)
    datawizard::data_rotate(as.data.frame(pair))
  })
  # now we iterate over all pairs and try to find the corresponding predictions
  out <- do.call(rbind, lapply(seq_len(nrow(pairs_data[[1]])), function(i) {
    pos1 <- predictions[[focal_terms[1]]] == pairs_data[[1]][i, 1]
    pos2 <- predictions[[focal_terms[1]]] == pairs_data[[2]][i, 1]

    if (length(focal_terms) > 1) {
      pos1 <- pos1 & predictions[[focal_terms[2]]] == pairs_data[[1]][i, 2]
      pos2 <- pos2 & predictions[[focal_terms[2]]] == pairs_data[[2]][i, 2]
    }
    if (length(focal_terms) > 2) {
      pos1 <- pos1 & predictions[[focal_terms[3]]] == pairs_data[[1]][i, 3]
      pos2 <- pos2 & predictions[[focal_terms[3]]] == pairs_data[[2]][i, 3]
    }
    # once we have found the correct rows for the pairs, we can calculate
    # the contrast. We need the predicted values first
    predicted1 <- predictions$Predicted[pos1]
    predicted2 <- predictions$Predicted[pos2]

    # we then create labels for the pairs. "result" is a data frame with
    # the labels (of the pairwise contrasts) as columns.
    result <- data.frame(
      Parameter = paste(
        paste0("(", paste(pairs_data[[1]][i, ], collapse = " "), ")"),
        paste0("(", paste(pairs_data[[2]][i, ], collapse = " "), ")"),
        sep = "-"
      ),
      stringsAsFactors = FALSE
    )
    # we then add the contrast and the standard error. for linear models, the
    # SE is sqrt(se1^2 + se2^2).
    result$Difference <- predicted1 - predicted2
    # sum of squared standard errors
    sum_se_squared <- predictions$SE[pos1]^2 + predictions$SE[pos2]^2
    # for non-Gaussian models, we subtract the covariance of the two predictions
    # but only if the vcov_matrix is not NULL and has the correct dimensions
    correct_row_dims <- nrow(vcov_matrix) > 0 && all(nrow(vcov_matrix) >= which(pos1))
    correct_col_dims <- ncol(vcov_matrix) > 0 && all(ncol(vcov_matrix) >= which(pos2))
    if (is.null(vcov_matrix) || !correct_row_dims || !correct_col_dims) {
      vcov_sub <- 0
    } else {
      vcov_sub <- vcov_matrix[which(pos1), which(pos2)]^2
    }
    # Avoid negative values in sqrt()
    if (vcov_sub >= sum_se_squared) {
      result$SE <- sqrt(sum_se_squared)
    } else {
      result$SE <- sqrt(sum_se_squared - vcov_sub)
    }
    result
  }))
  # add CI and p-values
  out$CI_low <- out$Difference - stats::qt(crit_factor, df = dof) * out$SE
  out$CI_high <- out$Difference + stats::qt(crit_factor, df = dof) * out$SE
  out$Statistic <- out$Difference / out$SE
  out$p <- 2 * stats::pt(abs(out$Statistic), df = dof, lower.tail = FALSE)

  # filter by "by"
  if (!is.null(by)) {
    idx <- rep_len(TRUE, nrow(out))
    for (filter_by in by) {
      # create index with "by" variables for each comparison pair
      filter_data <- do.call(cbind, lapply(pairs_data, function(i) {
        colnames(i) <- focal_terms
        i[filter_by]
      }))
      # check which pairs have identical values - these are the rows we want to keep
      idx <- idx & unname(apply(filter_data, 1, function(r) r[1] == r[2]))
    }
    # prepare filtered dataset
    filter_column <- pairs_data[[1]]
    colnames(filter_column) <- focal_terms
    # bind the filtered data to the output
    out <- cbind(filter_column[idx, by, drop = FALSE], out[idx, , drop = FALSE])
  }

  rownames(out) <- NULL
  out
}
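The matrix trick used above (fill a square matrix, blank out the diagonal and lower triangle, then convert with `as.table()`) is a compact way to enumerate each unordered pair of levels exactly once. A tiny standalone sketch with made-up levels:

# three hypothetical levels
lvls <- c("a", "b", "c")
M <- matrix(1, nrow = 3, ncol = 3, dimnames = list(lvls, lvls))
# drop the diagonal and lower triangle, so each unordered pair survives once
M[!upper.tri(M)] <- NA
pairs <- stats::na.omit(as.data.frame(as.table(M)))
pairs$Freq <- NULL
pairs  # three rows: a-b, a-c, b-c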
# interaction contrasts ----------------------------------------------------
.compute_interactions <- function(predictions, dof, vcov_matrix, datagrid, contrast, by, crit_factor) {
  # we need the focal terms and all unique values from the datagrid
  focal_terms <- c(contrast, by)
  at_list <- lapply(datagrid[focal_terms], unique)

  ## TODO: interaction contrasts currently only work for two focal terms
  if (length(focal_terms) != 2) {
    insight::format_error("Interaction contrasts currently only work for two focal terms.")
  }

  # create pairwise combinations of first focal term
  level_pairs <- at_list[[1]]
  M <- matrix(
    1,
    nrow = length(level_pairs),
    ncol = length(level_pairs),
    dimnames = list(level_pairs, level_pairs)
  )
  M[!upper.tri(M)] <- NA
  # table() works fine to create variables for these pairwise combinations
  pairs_focal1 <- stats::na.omit(as.data.frame(as.table(M)))
  pairs_focal1$Freq <- NULL

  # create pairwise combinations of second focal term
  level_pairs <- at_list[[2]]
  M <- matrix(
    1,
    nrow = length(level_pairs),
    ncol = length(level_pairs),
    dimnames = list(level_pairs, level_pairs)
  )
  M[!upper.tri(M)] <- NA
  # table() works fine to create variables for these pairwise combinations
  pairs_focal2 <- stats::na.omit(as.data.frame(as.table(M)))
  pairs_focal2$Freq <- NULL

  # now we iterate over all pairs and try to find the corresponding predictions
  out <- do.call(rbind, lapply(seq_len(nrow(pairs_focal1)), function(i) {
    # differences between levels of first focal term
    pos1 <- predictions[[focal_terms[1]]] == pairs_focal1[i, 1]
    pos2 <- predictions[[focal_terms[1]]] == pairs_focal1[i, 2]

    do.call(rbind, lapply(seq_len(nrow(pairs_focal2)), function(j) {
      # difference between levels of first focal term, *within* first
      # level of second focal term
      pos_1a <- pos1 & predictions[[focal_terms[2]]] == pairs_focal2[j, 1]
      pos_1b <- pos2 & predictions[[focal_terms[2]]] == pairs_focal2[j, 1]
      # difference between levels of first focal term, *within* second
      # level of second focal term
      pos_2a <- pos1 & predictions[[focal_terms[2]]] == pairs_focal2[j, 2]
      pos_2b <- pos2 & predictions[[focal_terms[2]]] == pairs_focal2[j, 2]
      # once we have found the correct rows for the pairs, we can calculate
      # the contrast. We need the predicted values first
      predicted1 <- predictions$Predicted[pos_1a] - predictions$Predicted[pos_1b]
      predicted2 <- predictions$Predicted[pos_2a] - predictions$Predicted[pos_2b]
      # we then create labels for the pairs. "result" is a data frame with
      # the labels (of the pairwise contrasts) as columns.
      result <- data.frame(
        a = paste(pairs_focal1[i, 1], pairs_focal1[i, 2], sep = "-"),
        b = paste(pairs_focal2[j, 1], pairs_focal2[j, 2], sep = " and "),
        stringsAsFactors = FALSE
      )
      colnames(result) <- focal_terms
      # we then add the contrast and the standard error. for linear models, the
      # SE is sqrt(se1^2 + se2^2)
      result$Difference <- predicted1 - predicted2
      sum_se_squared <- sum(
        predictions$SE[pos_1a]^2, predictions$SE[pos_1b]^2,
        predictions$SE[pos_2a]^2, predictions$SE[pos_2b]^2
      )
      # for non-Gaussian models, we subtract the covariance of the two predictions
      # but only if the vcov_matrix is not NULL and has the correct dimensions
      correct_row_dims <- nrow(vcov_matrix) > 0 && all(nrow(vcov_matrix) >= which(pos_1a)) && all(nrow(vcov_matrix) >= which(pos_2a)) # nolint
      correct_col_dims <- ncol(vcov_matrix) > 0 && all(ncol(vcov_matrix) >= which(pos_1b)) && all(ncol(vcov_matrix) >= which(pos_2b)) # nolint
      if (is.null(vcov_matrix) || !correct_row_dims || !correct_col_dims) {
        vcov_sub <- 0
      } else {
        vcov_sub <- sum(
          vcov_matrix[which(pos_1a), which(pos_1b)]^2,
          vcov_matrix[which(pos_2a), which(pos_2b)]^2
        )
      }
      # Avoid negative values in sqrt()
      if (vcov_sub >= sum_se_squared) {
        result$SE <- sqrt(sum_se_squared)
      } else {
        result$SE <- sqrt(sum_se_squared - vcov_sub)
      }
      result
    }))
  }))
  # add CI and p-values
  out$CI_low <- out$Difference - stats::qt(crit_factor, df = dof) * out$SE
  out$CI_high <- out$Difference + stats::qt(crit_factor, df = dof) * out$SE
  out$Statistic <- out$Difference / out$SE
  out$p <- 2 * stats::pt(abs(out$Statistic), df = dof, lower.tail = FALSE)
  out
}
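Both helper functions end with the same Wald-style inference: a t-based confidence interval and a two-sided p-value computed from `Difference / SE` with the model's degrees of freedom. A numeric illustration with made-up values:

difference <- 0.5  # hypothetical contrast estimate
se <- 0.2          # hypothetical standard error of the difference
dof <- 30          # hypothetical Wald degrees of freedom
ci <- 0.95
crit_factor <- (1 + ci) / 2

ci_low    <- difference - stats::qt(crit_factor, df = dof) * se
ci_high   <- difference + stats::qt(crit_factor, df = dof) * se
statistic <- difference / se
p         <- 2 * stats::pt(abs(statistic), df = dof, lower.tail = FALSE)

c(CI_low = ci_low, CI_high = ci_high, Statistic = statistic, p = p)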
