
Commit 0222a6a

predict.doc (#167)

1 parent f75079f

3 files changed (+46, -44 lines)


r-package/policytree/R/policy_tree.R

Lines changed: 23 additions & 22 deletions
@@ -89,7 +89,6 @@
 #' var.imp <- grf::variable_importance(c.forest)
 #' top.5 <- order(var.imp, decreasing = TRUE)[1:5]
 #' tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
-#'
 #' }
 #' @seealso \code{\link{hybrid_policy_tree}} for building deeper trees.
 #' @export
@@ -199,49 +198,51 @@ policy_tree <- function(X, Gamma, depth = 2, split.step = 1, min.node.size = 1,
 #' @method predict policy_tree
 #' @examples
 #' \donttest{
-#' # Fit a depth two tree on doubly robust treatment effect estimates from a causal forest.
+#' # Construct doubly robust scores using a causal forest.
 #' n <- 10000
 #' p <- 10
-#' # Discretizing continuous covariates decreases runtime.
+#' # Discretizing continuous covariates decreases runtime for policy learning.
 #' X <- round(matrix(rnorm(n * p), n, p), 2)
 #' colnames(X) <- make.names(1:p)
 #' W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
 #' tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
 #' Y <- X[, 3] + W * tau + rnorm(n)
 #' c.forest <- grf::causal_forest(X, Y, W)
+#'
+#' # Retrieve doubly robust scores.
 #' dr.scores <- double_robust_scores(c.forest)
 #'
-#' tree <- policy_tree(X, dr.scores, 2)
+#' # Learn a depth-2 tree on a training set.
+#' train <- sample(1:n, n / 2)
+#' tree <- policy_tree(X[train, ], dr.scores[train, ], 2)
 #' tree
 #'
-#' # Predict treatment assignment.
-#' predicted <- predict(tree, X)
+#' # Evaluate the tree on a test set.
+#' test <- -train
 #'
-#' plot(X[, 1], X[, 2], col = predicted)
-#' legend("topright", c("control", "treat"), col = c(1, 2), pch = 19)
-#' abline(0, -1, lty = 2)
+#' # One way to assess the policy is to see whether the leaf node (group) the test set samples
+#' # are predicted to belong to have mean outcomes in accordance with the prescribed policy.
 #'
-#' # Predict the leaf assigned to each sample.
-#' node.id <- predict(tree, X, type = "node.id")
-#' # Can be reshaped to a list of samples per leaf node with `split`.
-#' samples.per.leaf <- split(1:n, node.id)
+#' # Get the leaf node assigned to each test sample.
+#' node.id <- predict(tree, X[test, ], type = "node.id")
 #'
-#' # The value of all arms (along with SEs) by each leaf node.
-#' values <- aggregate(dr.scores, by = list(leaf.node = node.id),
-#'                     FUN = function(x) c(mean = mean(x), se = sd(x) / sqrt(length(x))))
-#' print(values, digits = 2)
+#' # Doubly robust estimates of E[Y(control)] and E[Y(treated)] by leaf node.
+#' values <- aggregate(dr.scores[test, ], by = list(leaf.node = node.id),
+#'                     FUN = function(dr) c(mean = mean(dr), se = sd(dr) / sqrt(length(dr))))
+#' print(values, digits = 1)
 #'
-#' # Take cost of treatment into account by offsetting the objective
+#' # Take cost of treatment into account by, for example, offsetting the objective
 #' # with an estimate of the average treatment effect.
-#' # See section 5.1 in Athey and Wager (2021) for more details, including
-#' # suggestions on using cross-validation to assess the accuracy of the learned policy.
 #' ate <- grf::average_treatment_effect(c.forest)
 #' cost.offset <- ate[["estimate"]]
 #' dr.scores[, "treated"] <- dr.scores[, "treated"] - cost.offset
 #' tree.cost <- policy_tree(X, dr.scores, 2)
 #'
-#' # If there are too many covariates to make tree search computationally feasible,
-#' # one can consider for example only the top 5 features according to GRF's variable importance.
+#' # Predict treatment assignment for each sample.
+#' predicted <- predict(tree, X)
+#'
+#' # If there are too many covariates to make tree search computationally feasible, then one
+#' # approach is to consider for example only the top features according to GRF's variable importance.
 #' var.imp <- grf::variable_importance(c.forest)
 #' top.5 <- order(var.imp, decreasing = TRUE)[1:5]
 #' tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
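A note on the assessment step introduced above (a sketch, not part of the commit): the new comments suggest checking whether the test-set leaf groups have mean outcomes in line with the prescribed policy. Assuming the objects `tree`, `X`, `test`, `node.id`, and `values` from the updated example, and score columns named "control" and "treated" as used there, one concrete check is (the helper names `prescribed`, `leaf.means`, and `estimated.best` are illustrative):

# Per leaf: the action the tree prescribes vs. the arm with the highest
# estimated mean doubly robust score (action 1 = control, 2 = treated).
prescribed <- tapply(predict(tree, X[test, ]), node.id, unique)
leaf.means <- cbind(values$control[, "mean"], values$treated[, "mean"])
estimated.best <- max.col(leaf.means)
data.frame(leaf.node = values$leaf.node, prescribed, estimated.best)

If the learned policy generalizes, `prescribed` and `estimated.best` should agree in most leaves, up to the reported standard errors.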

r-package/policytree/man/policy_tree.Rd

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default.

r-package/policytree/man/predict.policy_tree.Rd

Lines changed: 23 additions & 21 deletions
Some generated files are not rendered by default.
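As an aside, the two man/*.Rd entries change in lockstep with R/policy_tree.R because they are generated from its roxygen comments. A minimal sketch of the regeneration step, assuming the standard roxygen2 workflow (not stated in this commit):

# Regenerate man/*.Rd from the roxygen comments; run from r-package/policytree/.
roxygen2::roxygenise()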
