Skip to content

Commit e3a94ed

Browse files
committed
Cleaning and doc updates
1 parent 7a6f5c6 commit e3a94ed

23 files changed

+339
-409
lines changed

R/functions_for_processing.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,15 @@
452452
if (is_null(treated) || treated %nin% unique.vals) {
453453
.err('when `estimand = "ATT"` for multi-category treatments, an argument must be supplied to `focal`')
454454
}
455+
455456
focal <- treated
456457
}
457458
}
458459
else if (estimand == "ATC") {
459460
if (is_null(focal)) {
460461
.err('when `estimand = "ATC"` for multi-category treatments, an argument must be supplied to `focal`')
461462
}
463+
462464
estimand <- "ATT"
463465
}
464466
}

R/utils.R

Lines changed: 15 additions & 242 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,6 @@ hasbar <- function(term) {
608608
any(c("|", "||") %in% all.names(term))
609609
}
610610
get_varnames <- function(expr) {
611-
# Ensure we are working with a formula RHS
612611
recurse <- function(e) {
613612
if (is.symbol(e)) {
614613
# bare variable like age
@@ -619,248 +618,21 @@ get_varnames <- function(expr) {
619618
return(NULL)
620619
}
621620

622-
fn <- as.character(e[[1L]])
623-
621+
# keep as-is for $, [[, and [
622+
fn <- e[[1L]]
624623
if (fn == as.name("$") || fn == as.name("[[") || fn == as.name("[")) {
625-
# if (fn %in% c("$", "[[", "[")) {
626-
627-
# keep as-is for $, [[, and [
628624
return(deparse1(e))
629625
}
630626

631627
# strip outer function, recurse into arguments
632-
unlist(lapply(as.list(e)[-1L], recurse))
633-
628+
lapply(as.list(e)[-1L], recurse) |>
629+
unlist()
634630
}
635631

636632
recurse(expr)
637633
}
638634

639635
#treat/covs
640-
get_covs_and_treat_from_formula <- function(f, data = NULL, terms = FALSE, sep = "", ...) {
641-
642-
if (!rlang::is_formula(f)) {
643-
.err("`formula` must be a formula")
644-
}
645-
646-
env <- environment(f)
647-
648-
#Check if data exists
649-
if (is_null(data)) {
650-
data <- env
651-
data.specified <- FALSE
652-
}
653-
else if (is.data.frame(data)) {
654-
data.specified <- TRUE
655-
}
656-
else {
657-
.wrn("the argument supplied to `data` is not a data.frame object. This may causes errors or unexpected results")
658-
data <- env
659-
data.specified <- FALSE
660-
}
661-
662-
eval.model.matrx <- !hasbar(f)
663-
664-
tryCatch({
665-
tt <- terms(f, data = data)
666-
},
667-
error = function(e) {
668-
msg <- {
669-
if (conditionMessage(e) == "'.' in formula and no 'data' argument")
670-
"`.` is not allowed in formulas"
671-
else
672-
conditionMessage(e)
673-
}
674-
.err(msg)
675-
})
676-
677-
treat <- ...get("treat")
678-
treat.name <- NULL
679-
680-
#Check if response exists
681-
if (rlang::is_formula(tt, lhs = TRUE)) {
682-
resp.var.mentioned <- attr(tt, "variables")[[2L]]
683-
resp.var.mentioned.char <- deparse1(resp.var.mentioned)
684-
685-
resp.var.failed <- {
686-
test <- tryCatch(eval(resp.var.mentioned, data, env), error = function(e) e)
687-
if (!inherits(test, "simpleError")) {
688-
is_null(test)
689-
}
690-
else if (startsWith(conditionMessage(test), "object") &&
691-
endsWith(conditionMessage(test), "not found")) {
692-
TRUE
693-
}
694-
else {
695-
.err(conditionMessage(test), tidy = FALSE)
696-
}
697-
}
698-
699-
if (resp.var.failed) {
700-
if (is_null(treat)) {
701-
.err(sprintf("the given response variable, %s, is not a variable in %s",
702-
add_quotes(resp.var.mentioned.char),
703-
word_list(c("data", "the global environment")[c(data.specified, TRUE)], "or")))
704-
}
705-
tt <- delete.response(tt)
706-
}
707-
708-
if (!resp.var.failed) {
709-
treat.name <- resp.var.mentioned.char
710-
treat <- eval(resp.var.mentioned, data, env)
711-
}
712-
}
713-
714-
#Check if RHS variables exist
715-
tt.covs <- delete.response(tt)
716-
717-
rhs.vars.mentioned <- attr(tt.covs, "variables")[-1L]
718-
rhs.vars.mentioned.char <- vapply(rhs.vars.mentioned, deparse1, character(1L))
719-
rhs.vars.failed <- vapply(seq_along(rhs.vars.mentioned), function(i) {
720-
test <- tryCatch(eval(rhs.vars.mentioned[[i]], data, env), error = function(e) e)
721-
if (!inherits(test, "simpleError")) {
722-
return(is_null(test))
723-
}
724-
725-
if (!startsWith(conditionMessage(test), "object") ||
726-
!endsWith(conditionMessage(test), "not found")) {
727-
.err(conditionMessage(test), tidy = FALSE)
728-
}
729-
730-
TRUE
731-
}, logical(1L))
732-
733-
if (any(rhs.vars.failed)) {
734-
.err(sprintf("All variables in `formula` must be variables in `data` or objects in the global environment.\nMissing variables: %s",
735-
word_list(rhs.vars.mentioned.char[rhs.vars.failed], and.or = FALSE)), tidy = FALSE)
736-
737-
}
738-
739-
rhs.term.labels <- attr(tt.covs, "term.labels")
740-
rhs.term.orders <- attr(tt.covs, "order")
741-
742-
rhs.df <- setNames(vapply(rhs.vars.mentioned, function(v) {
743-
length(dim(try(eval(v, data, env), silent = TRUE))) == 2L
744-
}, logical(1L)), rhs.vars.mentioned.char)
745-
746-
rhs.term.labels.list <- setNames(as.list(rhs.term.labels), rhs.term.labels)
747-
if (any(rhs.df)) {
748-
if (any(rhs.vars.mentioned.char[rhs.df] %in% unlist(lapply(rhs.term.labels[rhs.term.orders > 1],
749-
strsplit, ":", fixed = TRUE)))) {
750-
.err("interactions with data.frames are not allowed in the input formula")
751-
}
752-
753-
addl.dfs <- setNames(lapply(which(rhs.df), function(i) {
754-
df <- eval(rhs.vars.mentioned[[i]], data, env)
755-
if (inherits(df, "rms")) {
756-
class(df) <- "matrix"
757-
df <- setNames(as.data.frame(as.matrix(df)), attr(df, "colnames"))
758-
}
759-
else if (can_str2num(colnames(df))) {
760-
colnames(df) <- paste(rhs.vars.mentioned.char[i], colnames(df), sep = sep)
761-
}
762-
763-
as.data.frame(df)
764-
}),
765-
rhs.vars.mentioned.char[rhs.df])
766-
767-
for (i in rhs.term.labels[rhs.term.labels %in% rhs.vars.mentioned.char[rhs.df]]) {
768-
ind <- which(rhs.term.labels == i)
769-
rhs.term.labels <- append(rhs.term.labels[-ind],
770-
values = names(addl.dfs[[i]]),
771-
after = ind - 1L)
772-
rhs.term.labels.list[[i]] <- names(addl.dfs[[i]])
773-
}
774-
775-
data <- {
776-
if (data.specified) do.call("cbind", unname(c(addl.dfs, list(data))))
777-
else do.call("cbind", unname(addl.dfs))
778-
}
779-
}
780-
781-
if (is_null(rhs.term.labels)) {
782-
new.form <- as.formula("~ 0")
783-
tt.covs <- terms(new.form)
784-
covs <- data.frame(Intercept = rep.int(1, if (is_null(treat)) 1L else length(treat)))[, -1L, drop = FALSE]
785-
}
786-
else {
787-
new.form.char <- sprintf("~ %s", paste(vapply(names(rhs.term.labels.list), function(x) {
788-
if (x %in% rhs.vars.mentioned.char[rhs.df]) paste0("`", rhs.term.labels.list[[x]], "`", collapse = " + ")
789-
else rhs.term.labels.list[[x]]
790-
} , character(1L)), collapse = " + "))
791-
792-
new.form <- as.formula(new.form.char)
793-
tt.covs <- terms(update(new.form, ~ . - 1))
794-
795-
#Get model.frame, report error
796-
mf.covs <- quote(stats::model.frame(tt.covs, data,
797-
drop.unused.levels = TRUE,
798-
na.action = "na.pass"))
799-
800-
covs <- tryCatch(eval(mf.covs),
801-
error = function(e) {
802-
.err(conditionMessage(e), tidy = FALSE)
803-
})
804-
805-
if (is_not_null(treat.name) && utils::hasName(covs, treat.name)) {
806-
.err("the variable on the left side of the formula appears on the right side too")
807-
}
808-
}
809-
810-
if (eval.model.matrx) {
811-
if (!is.character(sep) || length(sep) > 1L) {
812-
stop("'sep' must be a string of length 1.", call. = FALSE)
813-
}
814-
815-
s <- nzchar(sep)
816-
817-
if (s) original.covs.levels <- make_list(names(covs))
818-
819-
for (i in names(covs)) {
820-
if (is.character(covs[[i]])) {
821-
covs[[i]] <- factor(covs[[i]])
822-
}
823-
else if (!is.factor(covs[[i]])) {
824-
next
825-
}
826-
827-
if (length(unique(covs[[i]])) == 1L) {
828-
covs[[i]] <- 1
829-
}
830-
else if (s) {
831-
original.covs.levels[[i]] <- levels(covs[[i]])
832-
levels(covs[[i]]) <- paste0(sep, original.covs.levels[[i]])
833-
}
834-
}
835-
836-
#Get full model matrix with interactions too
837-
covs.matrix <- model.matrix(tt.covs, data = covs,
838-
contrasts.arg = lapply(Filter(is.factor, covs),
839-
contrasts, contrasts = FALSE))
840-
841-
if (s) {
842-
for (i in names(covs)[vapply(covs, is.factor, logical(1L))]) {
843-
levels(covs[[i]]) <- original.covs.levels[[i]]
844-
}
845-
}
846-
}
847-
else {
848-
covs.matrix <- NULL
849-
}
850-
851-
if (!terms) {
852-
attr(covs, "terms") <- NULL
853-
}
854-
855-
if (is_not_null(treat)) {
856-
class(treat) <- unique(c("treat", class(treat)))
857-
attr(treat, "treat.name") <- treat.name
858-
}
859-
860-
list(reported.covs = covs,
861-
model.covs = covs.matrix,
862-
treat = treat)
863-
}
864636
get_covs_and_treat_from_formula2 <- function(f, data = NULL, sep = "", ...) {
865637

866638
if (!rlang::is_formula(f)) {
@@ -1286,8 +1058,8 @@ rep_with <- function(x, y) {
12861058
rep.int(x, length(y)) |>
12871059
setNames(names(y))
12881060
}
1289-
is_null <- function(x) length(x) == 0L
1290-
is_not_null <- function(x) !is_null(x)
1061+
is_null <- function(x) {length(x) == 0L}
1062+
is_not_null <- function(x) {!is_null(x)}
12911063
if_null_then <- function(x1 = NULL, x2 = NULL, ...) {
12921064
if (is_not_null(x1)) {
12931065
return(x1)
@@ -1472,10 +1244,7 @@ check_if_call_from_fun <- function(fun) {
14721244

14731245
#Evaluate a call (usually a model call) with options for ignoring and recoding
14741246
#warnings and errors.
1475-
.eval_fit <- function(call,
1476-
envir = parent.frame(2L),
1477-
warnings = NULL,
1478-
errors = NULL,
1247+
.eval_fit <- function(call, envir = parent.frame(2L), warnings = NULL, errors = NULL,
14791248
from = TRUE) {
14801249
withCallingHandlers({
14811250
fit <- eval(call, envir = envir)
@@ -1493,10 +1262,12 @@ check_if_call_from_fun <- function(fun) {
14931262
.wrn(w, tidy = FALSE)
14941263
}
14951264
else if (isTRUE(from)) {
1496-
.wrn(sprintf("(from `%s()`) %s", rlang::call_name(call), w), tidy = FALSE)
1265+
.wrn(sprintf("(from `%s()`): %s", rlang::call_name(call), w),
1266+
tidy = FALSE)
14971267
}
14981268
else {
1499-
.wrn(sprintf("(from %s) %s", paste(from, collapse = ""), w), tidy = FALSE)
1269+
.wrn(sprintf("(from %s): %s", paste(from, collapse = ""), w),
1270+
tidy = FALSE)
15001271
}
15011272

15021273
invokeRestart("muffleWarning")
@@ -1514,10 +1285,12 @@ check_if_call_from_fun <- function(fun) {
15141285
.err(e, tidy = FALSE)
15151286
}
15161287
else if (isTRUE(from)) {
1517-
.err(sprintf("(from `%s()`) %s", rlang::call_name(call), e), tidy = FALSE)
1288+
.err(sprintf("(from `%s()`): %s", rlang::call_name(call), e),
1289+
tidy = FALSE)
15181290
}
15191291
else {
1520-
.err(sprintf("(from %s) %s", paste(from, collapse = ""), e), tidy = FALSE)
1292+
.err(sprintf("(from %s): %s", paste(from, collapse = ""), e),
1293+
tidy = FALSE)
15211294
}
15221295
})
15231296

R/weightit2bart.R

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,22 +164,28 @@
164164
#' (W1 <- weightit(treat ~ age + educ + married +
165165
#' nodegree + re74, data = lalonde,
166166
#' method = "bart", estimand = "ATT"))
167+
#'
167168
#' summary(W1)
169+
#'
168170
#' cobalt::bal.tab(W1)
169171
#'
170172
#' #Balancing covariates with respect to race (multi-category)
171173
#' (W2 <- weightit(race ~ age + educ + married +
172174
#' nodegree + re74, data = lalonde,
173175
#' method = "bart", estimand = "ATE"))
176+
#'
174177
#' summary(W2)
178+
#'
175179
#' cobalt::bal.tab(W2)
176180
#'
177181
#' #Balancing covariates with respect to re75 (continuous)
178182
#' #assuming t(3) conditional density for treatment
179183
#' (W3 <- weightit(re75 ~ age + educ + married +
180184
#' nodegree + re74, data = lalonde,
181185
#' method = "bart", density = "dt_3"))
186+
#'
182187
#' summary(W3)
188+
#'
183189
#' cobalt::bal.tab(W3)}
184190
NULL
185191

@@ -226,8 +232,9 @@ weightit2bart <- function(covs, treat, s.weights, subset, estimand, focal, stabi
226232
fit <- eval(bart.call)
227233
}, verbose = verbose)},
228234
error = function(e) {
229-
e. <- conditionMessage(e)
230-
.err("(from `dbarts::bart2()`) ", e., tidy = FALSE)
235+
.err(sprintf("(from `dbarts::bart2()`): %s",
236+
conditionMessage(e)),
237+
tidy = FALSE)
231238
})
232239

233240
p.score <- fitted(fit)
@@ -282,8 +289,9 @@ weightit2bart.multi <- function(covs, treat, s.weights, subset, estimand, focal
282289
fit.list[[i]] <- eval(bart.call)
283290
}, verbose = verbose)},
284291
error = function(e) {
285-
e. <- conditionMessage(e)
286-
.err("(from `dbarts::bart2()`) ", e., tidy = FALSE)
292+
.err(sprintf("(from `dbarts::bart2()`): %s",
293+
conditionMessage(e)),
294+
tidy = FALSE)
287295
})
288296

289297
ps[[i]] <- fitted(fit.list[[i]])
@@ -340,8 +348,9 @@ weightit2bart.cont <- function(covs, treat, s.weights, subset, stabilize, missin
340348
fit <- eval(bart.call)
341349
}, verbose = verbose)},
342350
error = function(e) {
343-
e. <- conditionMessage(e)
344-
.err("(from `dbarts::bart2()`) ", e., tidy = FALSE)
351+
.err(sprintf("(from `dbarts::bart2()`): %s",
352+
conditionMessage(e)),
353+
tidy = FALSE)
345354
})
346355

347356
r <- residuals(fit)

0 commit comments

Comments
 (0)