Skip to content

Commit e9cf9c1

Browse files
authored
Add maq_scale (#104)
1 parent d8437e2 commit e9cf9c1

File tree

7 files changed

+112
-10
lines changed

7 files changed

+112
-10
lines changed

r-package/maq/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export(get_aipw_scores)
1010
export(get_ipw_scores)
1111
export(integrated_difference)
1212
export(maq)
13+
export(maq_scale)
1314
importFrom(Rcpp,evalCpp)
1415
importFrom(stats,predict)
1516
useDynLib(maq, .registration = TRUE)

r-package/maq/R/maq.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,46 @@ integrated_difference <- function(object.lhs,
548548

549549
c(estimate = point.estimate, std.err = std.err)
550550
}
551+
552+
#' Scale a Qini curve.
553+
#'
554+
#' Remap the policy value and budget to some problem-specific application.
555+
#' This is a convenience function that is usually useful for plotting.
556+
#'
557+
#' @param object A maq object.
558+
#' @param scale A numeric value to scale by.
559+
#'
560+
#' @return A rescaled maq object.
561+
#'
562+
#' @examples
563+
#' \donttest{
564+
#' # Generate some single-arm toy data.
565+
#' n <- 1500
566+
#' K <- 1
567+
#' reward <- matrix(1 + runif(n * K), n, K)
568+
#' scores <- reward + 5 * matrix(rnorm(n * K), n, K)
569+
#' cost <- 1
570+
#'
571+
#' # Fit a Qini curve.
572+
#' qini <- maq(reward, cost, scores, R = 200)
573+
#'
574+
#' # Plot the policy values as we vary the fraction treated.
575+
#' qini |>
576+
#' plot(xlab = "Treated fraction")
577+
#'
578+
#' # Plot the policy values for a maximum allocation of, for example, 500 units.
579+
#' maq_scale(qini, 500) |>
580+
#' plot(xlab = "Units treated")
581+
#'}
582+
#' @export
583+
maq_scale <- function(object,
584+
scale = 1) {
585+
object[["_path"]]$spend <- scale * object[["_path"]]$spend
586+
object[["_path"]]$gain <- scale * object[["_path"]]$gain
587+
object[["_path"]]$std.err <- scale * object[["_path"]]$std.err
588+
object[["_path"]]$gain.bs <- lapply(object[["_path"]]$gain.bs,
589+
function(x) scale * x)
590+
object$budget <- object$budget * scale
591+
592+
object
593+
}

r-package/maq/R/plot.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
#' \donttest{
2424
#' if (require("ggplot2", quietly = TRUE)) {
2525
#' # Generate toy data and customize plots.
26-
#' n = 500
27-
#' K = 1
28-
#' reward = matrix(1 + rnorm(n * K), n, K)
29-
#' scores = reward + matrix(rnorm(n * K), n, K)
30-
#' cost = 1
26+
#' n <- 500
27+
#' K <- 1
28+
#' reward <- matrix(1 + rnorm(n * K), n, K)
29+
#' scores <- reward + matrix(rnorm(n * K), n, K)
30+
#' cost <- 1
3131
#'
3232
#' # Fit Qini curves.
3333
#' qini.avg <- maq(reward, cost, scores, R = 200, target.with.covariates = FALSE)

r-package/maq/man/maq_scale.Rd

Lines changed: 41 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

r-package/maq/man/plot.maq.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

r-package/maq/pkgdown/_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ reference:
1616
- title: Multi-Armed Qini
1717
contents:
1818
- maq
19+
- maq_scale
1920
- predict.maq
2021
- average_gain
2122
- difference_gain

r-package/maq/tests/testthat/test_maq.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,19 @@ test_that("integrated_difference grid lookup numerics works as expected", {
558558

559559
expect_equal(est, estt, tolerance = 1e-10)
560560
})
561+
562+
test_that("maq_scale works as expected", {
563+
n <- 500
564+
K <- 2
565+
reward <- matrix(1 + rnorm(n * K), n, K)
566+
scores <- reward + matrix(rnorm(n * K), n, K)
567+
cost <- matrix(1 + runif(n * K), n, K)
568+
569+
scale <- 100
570+
qini <- maq(reward, cost, scores, R = 200)
571+
qini.s <- maq(reward, cost * scale, scores * scale, R = 200)
572+
573+
expect_equal(qini.s, maq_scale(qini, scale))
574+
expect_equal(average_gain(qini.s, 42),
575+
average_gain(maq_scale(qini, scale), 42))
576+
})

0 commit comments

Comments
 (0)