#'
#' @examples
#' \donttest{
-#' # Fit a depth two tree on doubly robust treatment effect estimates from a causal forest.
+#' # Construct doubly robust scores using a causal forest.
#' n <- 10000
#' p <- 10
-#' # Discretizing continuous covariates decreases runtime.
+#' # Discretizing continuous covariates decreases runtime for policy learning.
#' X <- round(matrix(rnorm(n * p), n, p), 2)
|
49 | 49 | #' colnames(X) <- make.names(1:p)
|
50 | 50 | #' W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
|
51 | 51 | #' tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
|
52 | 52 | #' Y <- X[, 3] + W * tau + rnorm(n)
|
53 | 53 | #' c.forest <- grf::causal_forest(X, Y, W)
|
+#'
+#' # Retrieve doubly robust scores.
#' dr.scores <- double_robust_scores(c.forest)
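+#' # Note (editor's addition, inferred from the column names used below): dr.scores
+#' # is a matrix of AIPW scores with one column per action, here "control" and "treated".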
#'
-#' tree <- policy_tree(X, dr.scores, 2)
+#' # Learn a depth-2 tree on a training set.
+#' train <- sample(1:n, n / 2)
+#' tree <- policy_tree(X[train, ], dr.scores[train, ], 2)
#' tree
#'
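An editorial aside, not in the upstream change: policytree also ships a plot method for the fitted tree; treating the DiagrammeR dependency as an assumption on my part, it could be mentioned here.

+#' # Optionally visualize the tree (editor's sketch; assumes DiagrammeR is installed).
+#' # plot(tree)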
-#' # Predict treatment assignment.
-#' predicted <- predict(tree, X)
+#' # Evaluate the tree on a test set.
+#' test <- -train
#'
-#' plot(X[, 1], X[, 2], col = predicted)
-#' legend("topright", c("control", "treat"), col = c(1, 2), pch = 19)
-#' abline(0, -1, lty = 2)
+#' # One way to assess the policy is to check whether the leaf nodes (groups) the test
+#' # samples are assigned to have mean outcomes in accordance with the prescribed policy.
#'
-#' # Predict the leaf assigned to each sample.
-#' node.id <- predict(tree, X, type = "node.id")
-#' # Can be reshaped to a list of samples per leaf node with `split`.
-#' samples.per.leaf <- split(1:n, node.id)
+#' # Get the leaf node assigned to each test sample.
+#' node.id <- predict(tree, X[test, ], type = "node.id")
#'
-#' # The value of all arms (along with SEs) by each leaf node.
-#' values <- aggregate(dr.scores, by = list(leaf.node = node.id),
-#'   FUN = function(x) c(mean = mean(x), se = sd(x) / sqrt(length(x))))
-#' print(values, digits = 2)
+#' # Doubly robust estimates of E[Y(control)] and E[Y(treated)] by leaf node.
+#' values <- aggregate(dr.scores[test, ], by = list(leaf.node = node.id),
+#'   FUN = function(dr) c(mean = mean(dr), se = sd(dr) / sqrt(length(dr))))
+#' print(values, digits = 1)
#'
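A sketch of mine, not part of the upstream change: the same scores give a one-line estimate of the learned policy's value on the test set, by selecting each sample's score for the prescribed action (predict returns actions as column indices, here 1 = control, 2 = treated).

+#' # Sketch (editor's addition): estimated test-set value of the learned policy.
+#' pi.hat <- predict(tree, X[test, ])
+#' mean(dr.scores[test, ][cbind(seq_along(pi.hat), pi.hat)])
+#'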
-#' # Take cost of treatment into account by offsetting the objective
+#' # Take cost of treatment into account by, for example, offsetting the objective
#' # with an estimate of the average treatment effect.
-#' # See section 5.1 in Athey and Wager (2021) for more details, including
-#' # suggestions on using cross-validation to assess the accuracy of the learned policy.
#' ate <- grf::average_treatment_effect(c.forest)
#' cost.offset <- ate[["estimate"]]
#' dr.scores[, "treated"] <- dr.scores[, "treated"] - cost.offset
#' tree.cost <- policy_tree(X, dr.scores, 2)
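+#' # (Editor's note) With this offset the tree prescribes treatment only where the
+#' # estimated gain exceeds the average treatment effect, which stands in for a cost.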
#'
-#' # If there are too many covariates to make tree search computationally feasible,
-#' # one can consider for example only the top 5 features according to GRF's variable importance.
+#' # Predict treatment assignment for each sample.
+#' predicted <- predict(tree, X)
+#'
+#' # If there are too many covariates for tree search to be computationally feasible, one
+#' # approach is to consider, for example, only the top features by GRF's variable importance.
#' var.imp <- grf::variable_importance(c.forest)
#' top.5 <- order(var.imp, decreasing = TRUE)[1:5]
#' tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
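+#' # Note (editor's addition): split.step = 50 approximates the search by considering
+#' # only every 50th candidate split point, trading some accuracy for a large speedup.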
+#'
#' }
#' @seealso \code{\link{hybrid_policy_tree}} for building deeper trees.
#' @export