89 | 89 | #' var.imp <- grf::variable_importance(c.forest)
90 | 90 | #' top.5 <- order(var.imp, decreasing = TRUE)[1:5]
91 | 91 | #' tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
92 |     | -#'
93 | 92 | #' }
94 | 93 | #' @seealso \code{\link{hybrid_policy_tree}} for building deeper trees.
95 | 94 | #' @export
@@ -199,49 +198,51 @@ policy_tree <- function(X, Gamma, depth = 2, split.step = 1, min.node.size = 1,
199 | 198 | #' @method predict policy_tree
200 | 199 | #' @examples
201 | 200 | #' \donttest{
202 |     | -#' # Fit a depth two tree on doubly robust treatment effect estimates from a causal forest.
    | 201 | +#' # Construct doubly robust scores using a causal forest.
203 | 202 | #' n <- 10000
204 | 203 | #' p <- 10
205 |     | -#' # Discretizing continuous covariates decreases runtime.
    | 204 | +#' # Discretizing continuous covariates decreases runtime for policy learning.
206 | 205 | #' X <- round(matrix(rnorm(n * p), n, p), 2)
207 | 206 | #' colnames(X) <- make.names(1:p)
208 | 207 | #' W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
209 | 208 | #' tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
210 | 209 | #' Y <- X[, 3] + W * tau + rnorm(n)
211 | 210 | #' c.forest <- grf::causal_forest(X, Y, W)
    | 211 | +#'
    | 212 | +#' # Retrieve doubly robust scores.
212 | 213 | #' dr.scores <- double_robust_scores(c.forest)
213 | 214 | #'
214 |     | -#' tree <- policy_tree(X, dr.scores, 2)
    | 215 | +#' # Learn a depth-2 tree on a training set.
    | 216 | +#' train <- sample(1:n, n / 2)
    | 217 | +#' tree <- policy_tree(X[train, ], dr.scores[train, ], 2)
215 | 218 | #' tree
216 | 219 | #'
217 |     | -#' # Predict treatment assignment.
218 |     | -#' predicted <- predict(tree, X)
    | 220 | +#' # Evaluate the tree on a test set.
    | 221 | +#' test <- -train
219 | 222 | #'
220 |     | -#' plot(X[, 1], X[, 2], col = predicted)
221 |     | -#' legend("topright", c("control", "treat"), col = c(1, 2), pch = 19)
222 |     | -#' abline(0, -1, lty = 2)
    | 223 | +#' # One way to assess the policy is to see whether the leaf nodes (groups) the test set samples
    | 224 | +#' # are predicted to belong to have mean outcomes in accordance with the prescribed policy.
223 | 225 | #'
224 |     | -#' # Predict the leaf assigned to each sample.
225 |     | -#' node.id <- predict(tree, X, type = "node.id")
226 |     | -#' # Can be reshaped to a list of samples per leaf node with `split`.
227 |     | -#' samples.per.leaf <- split(1:n, node.id)
    | 226 | +#' # Get the leaf node assigned to each test sample.
    | 227 | +#' node.id <- predict(tree, X[test, ], type = "node.id")
228 | 228 | #'
229 |     | -#' # The value of all arms (along with SEs) by each leaf node.
230 |     | -#' values <- aggregate(dr.scores, by = list(leaf.node = node.id),
231 |     | -#'   FUN = function(x) c(mean = mean(x), se = sd(x) / sqrt(length(x))))
232 |     | -#' print(values, digits = 2)
    | 229 | +#' # Doubly robust estimates of E[Y(control)] and E[Y(treated)] by leaf node.
    | 230 | +#' values <- aggregate(dr.scores[test, ], by = list(leaf.node = node.id),
    | 231 | +#'   FUN = function(dr) c(mean = mean(dr), se = sd(dr) / sqrt(length(dr))))
    | 232 | +#' print(values, digits = 1)
233 | 233 | #'
234 |     | -#' # Take cost of treatment into account by offsetting the objective
    | 234 | +#' # Take the cost of treatment into account by, for example, offsetting the objective
235 | 235 | #' # with an estimate of the average treatment effect.
236 |     | -#' # See section 5.1 in Athey and Wager (2021) for more details, including
237 |     | -#' # suggestions on using cross-validation to assess the accuracy of the learned policy.
238 | 236 | #' ate <- grf::average_treatment_effect(c.forest)
239 | 237 | #' cost.offset <- ate[["estimate"]]
240 | 238 | #' dr.scores[, "treated"] <- dr.scores[, "treated"] - cost.offset
241 | 239 | #' tree.cost <- policy_tree(X, dr.scores, 2)
242 | 240 | #'
243 |     | -#' # If there are too many covariates to make tree search computationally feasible,
244 |     | -#' # one can consider for example only the top 5 features according to GRF's variable importance.
    | 241 | +#' # Predict treatment assignment for each sample.
    | 242 | +#' predicted <- predict(tree, X)
    | 243 | +#'
    | 244 | +#' # If there are too many covariates to make tree search computationally feasible, then one
    | 245 | +#' # approach is to consider, for example, only the top features according to GRF's variable importance.
245 | 246 | #' var.imp <- grf::variable_importance(c.forest)
246 | 247 | #' top.5 <- order(var.imp, decreasing = TRUE)[1:5]
247 | 248 | #' tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
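Note: in the spirit of the train/test evaluation discussed in section 5.1 of Athey and Wager (2021), which the removed comment lines pointed to, a complementary check is to estimate the test-set value of the learned policy directly. A minimal sketch, not part of this commit, reusing the objects `tree`, `test`, and `dr.scores` from the revised example (taking `dr.scores` before the cost offset is applied); the names `action`, `dr.test`, and `policy.value` are illustrative:

# Prescribed arm per test sample: 1 = control, 2 = treated.
action <- predict(tree, X[test, ])
dr.test <- dr.scores[test, ]
# Doubly robust value estimate: average each sample's score for its prescribed arm.
policy.value <- dr.test[cbind(seq_along(action), action)]
mean(policy.value)
sd(policy.value) / sqrt(length(policy.value))  # standard error

The matrix index cbind(seq_along(action), action) picks out, for each test row, the score column matching the tree's prescribed action, so the mean is a doubly robust estimate of the mean outcome under the policy.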