Skip to content

Commit c0f19ee

Browse files
committed
Added support for new osqp version
1 parent c7f7042 commit c0f19ee

File tree

5 files changed

+52
-126
lines changed

5 files changed

+52
-126
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: WeightIt
22
Type: Package
33
Title: Weighting for Covariate Balance in Observational Studies
4-
Version: 1.5.1.9000
4+
Version: 1.5.1.9001
55
Authors@R: c(
66
person("Noah", "Greifer", role=c("aut", "cre"),
77
email = "noah.greifer@gmail.com",

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ WeightIt News and Updates
55

66
* Added `method = "cfd"` for characteristic function distance balancing as described by [Santra, Chen, and Park (2026)](http://arxiv.org/abs/2601.15449). Energy balancing is a special case of this method.
77

8+
* Added support for the new version of *osqp*, which changes some optional argument names and defaults for `method = "energy"`. These should not impact results.
9+
810
# `WeightIt` 1.5.1
911

1012
* `calibrate()` with `method = "isoreg"` can now be used with sampling weights.

R/functions_for_processing.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,49 @@
363363
missing
364364
}
365365

366+
.process_osqp_settings <- function(min.w, verbose, ...) {
367+
A <- ...mget(rlang::fn_fmls_names(osqp::osqpSettings))
368+
369+
eps <- ...get("eps", squish(min.w, lo = 1e-12, hi = 1e-8))
370+
if (is_not_null(eps)) {
371+
chk::chk_number(eps)
372+
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- eps
373+
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- eps
374+
}
375+
376+
if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 5e4L
377+
chk::chk_count(A[["max_iter"]], "`max_iter`")
378+
chk::chk_lt(A[["max_iter"]], Inf, "`max_iter`")
379+
380+
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1e-8
381+
chk::chk_number(A[["eps_abs"]], "`eps_abs`")
382+
383+
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1e-6
384+
chk::chk_number(A[["eps_rel"]], "`eps_rel`")
385+
386+
if (is_null(A[["time_limit"]])) A[["time_limit"]] <- 0
387+
chk::chk_number(A[["time_limit"]], "`time_limit`")
388+
389+
if (is_null(A[["adaptive_rho_interval"]])) A[["adaptive_rho_interval"]] <- 50L
390+
chk::chk_count(A[["adaptive_rho_interval"]], "`adaptive_rho_interval`")
391+
392+
if (packageVersion("osqp") >= "1.0.0") {
393+
A[["polishing"]] <- ...get("polishing") %or% ...get("polish") %or% TRUE
394+
chk::chk_flag(A[["polishing"]], "`polishing`")
395+
}
396+
else {
397+
A[["polish"]] <- ...get("polish") %or% ...get("polishing") %or% TRUE
398+
chk::chk_flag(A[["polish"]], "`polish`")
399+
}
400+
401+
if (is_null(A[["polish_refine_iter"]])) A[["polish_refine_iter"]] <- 20L
402+
chk::chk_count(A[["polish_refine_iter"]], "`polish_refine_iter`")
403+
404+
A[["verbose"]] <- verbose
405+
406+
do.call(osqp::osqpSettings, A)
407+
}
408+
366409
.check_user_method <- function(method) {
367410
#Check to make sure it accepts treat and covs
368411
if (all(c("covs", "treat") %in% rlang::fn_fmls_names(method))) {

R/weightit2cfd.R

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -232,33 +232,7 @@ weightit2cfd <- function(covs, treat, s.weights, subset, estimand, focal,
232232
tols <- abs(tols)
233233
}
234234

235-
A <- ...mget(names(formals(osqp::osqpSettings)))
236-
237-
eps <- ...get("eps", squish(min.w, lo = 1e-12, hi = 1e-8))
238-
if (is_not_null(eps)) {
239-
chk::chk_number(eps)
240-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- eps
241-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- eps
242-
}
243-
244-
if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 5e4L
245-
chk::chk_count(A[["max_iter"]], "`max_iter`")
246-
chk::chk_lt(A[["max_iter"]], Inf, "`max_iter`")
247-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1e-8
248-
chk::chk_number(A[["eps_abs"]], "`eps_abs`")
249-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1e-6
250-
chk::chk_number(A[["eps_rel"]], "`eps_rel`")
251-
if (is_null(A[["time_limit"]])) A[["time_limit"]] <- 0
252-
chk::chk_number(A[["time_limit"]], "`time_limit`")
253-
if (is_null(A[["adaptive_rho_interval"]])) A[["adaptive_rho_interval"]] <- 10L
254-
chk::chk_count(A[["adaptive_rho_interval"]], "`adaptive_rho_interval`")
255-
if (is_null(A[["polish"]])) A[["polish"]] <- TRUE
256-
chk::chk_flag(A[["polish"]], "`polish`")
257-
if (is_null(A[["polish_refine_iter"]])) A[["polish_refine_iter"]] <- 20L
258-
chk::chk_count(A[["polish_refine_iter"]], "`polish_refine_iter`")
259-
A[["verbose"]] <- verbose
260-
261-
options.list <- do.call(osqp::osqpSettings, A)
235+
options.list <- .process_osqp_settings(min.w, verbose, ...)
262236

263237
t0 <- which(treat == 0)
264238
t1 <- which(treat == 1)
@@ -531,33 +505,7 @@ weightit2cfd.multi <- function(covs, treat, s.weights, subset, estimand, focal,
531505
tols <- abs(tols)
532506
}
533507

534-
A <- ...mget(names(formals(osqp::osqpSettings)))
535-
536-
eps <- ...get("eps", squish(min.w, lo = 1e-12, hi = 1e-8))
537-
if (is_not_null(eps)) {
538-
chk::chk_number(eps)
539-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- eps
540-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- eps
541-
}
542-
543-
if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 5e4L
544-
chk::chk_count(A[["max_iter"]], "`max_iter`")
545-
chk::chk_lt(A[["max_iter"]], Inf, "`max_iter`")
546-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1e-8
547-
chk::chk_number(A[["eps_abs"]], "`eps_abs`")
548-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1e-6
549-
chk::chk_number(A[["eps_rel"]], "`eps_rel`")
550-
if (is_null(A[["time_limit"]])) A[["time_limit"]] <- 0
551-
chk::chk_number(A[["time_limit"]], "`time_limit`")
552-
if (is_null(A[["adaptive_rho_interval"]])) A[["adaptive_rho_interval"]] <- 10L
553-
chk::chk_count(A[["adaptive_rho_interval"]], "`adaptive_rho_interval`")
554-
if (is_null(A[["polish"]])) A[["polish"]] <- TRUE
555-
chk::chk_flag(A[["polish"]], "`polish`")
556-
if (is_null(A[["polish_refine_iter"]])) A[["polish_refine_iter"]] <- 20L
557-
chk::chk_count(A[["polish_refine_iter"]], "`polish_refine_iter`")
558-
A[["verbose"]] <- verbose
559-
560-
options.list <- do.call(osqp::osqpSettings, A)
508+
options.list <- .process_osqp_settings(min.w, verbose, ...)
561509

562510
treat_t <- matrix(0, nrow = n, ncol = length(levels_treat),
563511
dimnames = list(NULL, levels_treat))

R/weightit2energy.R

Lines changed: 4 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -467,30 +467,7 @@ weightit2energy <- function(covs, treat, s.weights, subset, estimand, focal,
467467
ATC = diag(P) + lambda * s.weights_n_1[t1]^2 / 2)
468468
}
469469

470-
A <- ...mget(names(formals(osqp::osqpSettings)))
471-
472-
if (is_not_null(...get("eps"))) {
473-
chk::chk_number(...get("eps"), "`eps`")
474-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- ...get("eps")
475-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- ...get("eps")
476-
}
477-
478-
if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 5e4L
479-
chk::chk_count(A[["max_iter"]], "`max_iter`")
480-
chk::chk_lt(A[["max_iter"]], Inf, "`max_iter`")
481-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1e-8
482-
chk::chk_number(A[["eps_abs"]], "`eps_abs`")
483-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1e-6
484-
chk::chk_number(A[["eps_rel"]], "`eps_rel`")
485-
if (is_null(A[["time_limit"]])) A[["time_limit"]] <- 0
486-
chk::chk_number(A[["time_limit"]], "`time_limit`")
487-
if (is_null(A[["adaptive_rho_interval"]])) A[["adaptive_rho_interval"]] <- 10L
488-
chk::chk_count(A[["adaptive_rho_interval"]], "`adaptive_rho_interval`")
489-
if (is_null(A[["polish"]])) A[["polish"]] <- TRUE
490-
chk::chk_flag(A[["polish"]], "`polish`")
491-
A[["verbose"]] <- TRUE
492-
493-
options.list <- do.call(osqp::osqpSettings, A)
470+
options.list <- .process_osqp_settings(min.w, verbose, ...)
494471

495472
verbosely({
496473
opt.out <- osqp::solve_osqp(P = 2 * P, q = q, A = t(Amat), l = lvec, u = uvec,
@@ -730,30 +707,7 @@ weightit2energy.multi <- function(covs, treat, s.weights, subset, estimand, foca
730707

731708
diag(P) <- diag(P) + lambda / n^2
732709

733-
A <- ...mget(names(formals(osqp::osqpSettings)))
734-
735-
if (is_not_null(...get("eps"))) {
736-
chk::chk_number(...get("eps"), "`eps`")
737-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- ...get("eps")
738-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- ...get("eps")
739-
}
740-
741-
if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 5e4L
742-
chk::chk_count(A[["max_iter"]], "`max_iter`")
743-
chk::chk_lt(A[["max_iter"]], Inf, "`max_iter`")
744-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- 1e-8
745-
chk::chk_number(A[["eps_abs"]], "`eps_abs`")
746-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- 1e-6
747-
chk::chk_number(A[["eps_rel"]], "`eps_rel`")
748-
if (is_null(A[["time_limit"]])) A[["time_limit"]] <- 0
749-
chk::chk_number(A[["time_limit"]], "`time_limit`")
750-
if (is_null(A[["adaptive_rho_interval"]])) A[["adaptive_rho_interval"]] <- 10L
751-
chk::chk_count(A[["adaptive_rho_interval"]], "`adaptive_rho_interval`")
752-
if (is_null(A[["polish"]])) A[["polish"]] <- TRUE
753-
chk::chk_flag(A[["polish"]], "`polish`")
754-
A[["verbose"]] <- TRUE
755-
756-
options.list <- do.call(osqp::osqpSettings, A)
710+
options.list <- .process_osqp_settings(min.w, verbose, ...)
757711

758712
verbosely({
759713
opt.out <- osqp::solve_osqp(P = 2 * P, q = q, A = t(Amat), l = lvec, u = uvec,
@@ -866,6 +820,8 @@ weightit2energy.cont <- function(covs, treat, s.weights, subset, missing, verbos
866820
tols <- abs(tols)
867821
}
868822

823+
options.list <- .process_osqp_settings(min.w, verbose, ...)
824+
869825
d.moments <- max(...get("d.moments", 0), moments)
870826
chk::chk_count(d.moments)
871827

@@ -957,29 +913,6 @@ weightit2energy.cont <- function(covs, treat, s.weights, subset, missing, verbos
957913

958914
diag(P) <- diag(P) + lambda / n^2
959915

960-
A <- ...mget(names(formals(osqp::osqpSettings)))
961-
962-
eps <- ...get("eps", 1e-8)
963-
964-
chk::chk_number(eps)
965-
if (is_null(A[["eps_abs"]])) A[["eps_abs"]] <- eps
966-
if (is_null(A[["eps_rel"]])) A[["eps_rel"]] <- eps
967-
968-
if (is_null(A[["max_iter"]])) A[["max_iter"]] <- 5e4L
969-
chk::chk_count(A[["max_iter"]], "`max_iter`")
970-
chk::chk_lt(A[["max_iter"]], Inf, "`max_iter`")
971-
chk::chk_number(A[["eps_abs"]], "`eps_abs`")
972-
chk::chk_number(A[["eps_rel"]], "`eps_rel`")
973-
if (is_null(A[["time_limit"]])) A[["time_limit"]] <- 0
974-
chk::chk_number(A[["time_limit"]], "`time_limit`")
975-
if (is_null(A[["adaptive_rho_interval"]])) A[["adaptive_rho_interval"]] <- 10L
976-
chk::chk_count(A[["adaptive_rho_interval"]], "`adaptive_rho_interval`")
977-
if (is_null(A[["polish"]])) A[["polish"]] <- TRUE
978-
chk::chk_flag(A[["polish"]], "`polish`")
979-
A[["verbose"]] <- TRUE
980-
981-
options.list <- do.call(osqp::osqpSettings, A)
982-
983916
verbosely({
984917
opt.out <- osqp::solve_osqp(P = 2 * P, q = q, A = t(Amat), l = lvec, u = uvec,
985918
pars = options.list)

0 commit comments

Comments
 (0)