-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathadmm_bindings.cpp
More file actions
123 lines (107 loc) · 4.94 KB
/
admm_bindings.cpp
File metadata and controls
123 lines (107 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <algorithm>
#include <cctype>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <pybind11/numpy.h>

#include "admm.hpp"

namespace py = pybind11;
using admm::Options;
using admm::Mode;
// --------- kwargs -> Options ----------
// Translate the keyword arguments accepted by solve_dense / solve_sparse_scipy
// into an admm::Options. Keys that are absent keep the Options defaults.
//
// Recognised keys: alpha, rho, max_iter, primal_tol, dual_tol,
// rho_update_steps, rho_update_ratio, min_rho, max_rho, mode.
// `mode` is matched case-insensitively against "Standard", "StandardSlack"
// and "ADMMslack".
//
// Throws py::type_error (Python TypeError) for an unrecognised keyword, so a
// misspelled option is reported instead of silently using the default.
// Throws py::value_error (Python ValueError) for an unrecognised mode string.
// A value of the wrong Python type raises py::cast_error from .cast<>().
static Options parse_options(const py::kwargs& kw) {
    static const char* const kKnownKeys[] = {
        "alpha",            "rho",              "max_iter",
        "primal_tol",       "dual_tol",         "rho_update_steps",
        "rho_update_ratio", "min_rho",          "max_rho",
        "mode"};

    // Reject unknown keywords up front, mirroring Python's own
    // "unexpected keyword argument" error for regular functions.
    for (auto item : kw) {
        const std::string key = item.first.cast<std::string>();
        const bool known = std::any_of(std::begin(kKnownKeys), std::end(kKnownKeys),
                                       [&key](const char* k) { return key == k; });
        if (!known)
            throw py::type_error("unexpected keyword argument '" + key + "'");
    }

    Options o;
    if (kw.contains("alpha")) o.alpha = kw["alpha"].cast<double>();
    if (kw.contains("rho")) o.rho = kw["rho"].cast<double>();
    if (kw.contains("max_iter")) o.max_iter = kw["max_iter"].cast<int>();
    if (kw.contains("primal_tol")) o.primal_tol = kw["primal_tol"].cast<double>();
    if (kw.contains("dual_tol")) o.dual_tol = kw["dual_tol"].cast<double>();
    if (kw.contains("rho_update_steps")) o.rho_update_steps = kw["rho_update_steps"].cast<int>();
    if (kw.contains("rho_update_ratio")) o.rho_update_ratio = kw["rho_update_ratio"].cast<double>();
    if (kw.contains("min_rho")) o.min_rho = kw["min_rho"].cast<double>();
    if (kw.contains("max_rho")) o.max_rho = kw["max_rho"].cast<double>();

    if (kw.contains("mode")) {
        const std::string raw = kw["mode"].cast<std::string>();
        // Lower-case a copy so any capitalisation of a valid name is accepted.
        std::string s = raw;
        std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
            return static_cast<char>(std::tolower(c));
        });
        if (s == "admmslack") {
            o.mode = Mode::ADMMslack;
        } else if (s == "standardslack") {
            o.mode = Mode::StandardSlack;
        } else if (s == "standard") {
            o.mode = Mode::Standard;
        } else {
            // Previously any typo silently selected Standard; surface it instead.
            throw py::value_error("unknown mode '" + raw +
                                  "'; expected 'Standard', 'StandardSlack' or 'ADMMslack'");
        }
    }
    return o;
}
// ---- Build Eigen::SparseMatrix from a SciPy sparse matrix (CSR/CSC/COO) ----
// Accepts any object exposing .tocoo(), converts it to COO form and assembles
// a compressed Eigen::SparseMatrix<double>.
// The row/col arrays are converted to int64 and the data to float64 through
// pybind11's array_t conversion (array_t defaults to forcecast).
// Duplicate (row, col) entries are summed by setFromTriplets.
//
// Throws std::runtime_error (Python RuntimeError) if the object has no
// .tocoo(), the COO arrays differ in length, the shape or number of stored
// entries does not fit in int (dimensions and indices are narrowed to int
// below), or any index lies outside the reported shape.
static Eigen::SparseMatrix<double> scipy_to_eigen_sparse(const py::object& spmat) {
    if (!py::hasattr(spmat, "tocoo"))
        throw std::runtime_error("Expected a SciPy sparse matrix (has no .tocoo()).");
    py::object coo = spmat.attr("tocoo")();

    py::array_t<long long> I = coo.attr("row").cast<py::array_t<long long>>();
    py::array_t<long long> J = coo.attr("col").cast<py::array_t<long long>>();
    py::array_t<double> V = coo.attr("data").cast<py::array_t<double>>();
    const auto shape = coo.attr("shape").cast<std::pair<py::ssize_t, py::ssize_t>>();

    // Guard the narrowing casts to int so oversized inputs fail loudly
    // instead of wrapping to bogus dimensions.
    const py::ssize_t kIntMax = static_cast<py::ssize_t>(std::numeric_limits<int>::max());
    if (shape.first < 0 || shape.second < 0 ||
        shape.first > kIntMax || shape.second > kIntMax)
        throw std::runtime_error("SciPy matrix shape does not fit in int indices");
    const int rows = static_cast<int>(shape.first);
    const int cols = static_cast<int>(shape.second);

    if (I.size() != J.size() || I.size() != V.size())
        throw std::runtime_error("SciPy COO arrays must have the same length");
    if (V.size() > kIntMax)
        throw std::runtime_error("SciPy matrix has too many stored entries for int indices");

    std::vector<Eigen::Triplet<double>> trips;
    trips.reserve(static_cast<size_t>(V.size()));
    auto iacc = I.unchecked<1>();
    auto jacc = J.unchecked<1>();
    auto vacc = V.unchecked<1>();
    // py::ssize_t instead of POSIX ssize_t, which is not defined on MSVC.
    for (py::ssize_t k = 0; k < I.size(); ++k) {
        const long long r = iacc(k);
        const long long c = jacc(k);
        if (r < 0 || r >= rows || c < 0 || c >= cols)
            throw std::runtime_error("SciPy COO index out of bounds");
        trips.emplace_back(static_cast<int>(r), static_cast<int>(c), vacc(k));
    }

    Eigen::SparseMatrix<double> M(rows, cols);
    M.setFromTriplets(trips.begin(), trips.end());
    M.makeCompressed();
    return M;
}
// Python extension module `admm_core`.
// Exposes admm::Result as a read-only class plus two solver entry points
// (dense NumPy input and SciPy sparse input) that accept the same optional
// keyword settings, decoded by parse_options.
PYBIND11_MODULE(admm_core, m) {
    m.doc() = "ADMM dense/sparse solvers (Eigen + pybind11)";

    // Solver output; all fields are exposed read-only.
    py::class_<admm::Result>(m, "Result")
        .def_readonly("x", &admm::Result::x)
        .def_readonly("iters", &admm::Result::iters)
        .def_readonly("primal_inf_norm", &admm::Result::primal_inf_norm)
        .def_readonly("dual_inf_norm", &admm::Result::dual_inf_norm)
        .def_readonly("rho_out", &admm::Result::rho_out)
        .def_readonly("converged", &admm::Result::converged);

    // Dense: NumPy arrays are converted to Eigen types by pybind11/eigen.h.
    m.def("solve_dense",
          [](const Eigen::MatrixXd& Q,
             const Eigen::MatrixXd& A,
             const Eigen::VectorXd& q,
             const Eigen::VectorXd& l,
             const Eigen::VectorXd& u,
             int n,
             const Eigen::VectorXd& x0,
             const py::kwargs& kw) {
              auto opts = parse_options(kw);
              return admm::solve_dense(Q, A, q, l, u, n, x0, opts);
          },
          py::arg("Q"), py::arg("A"), py::arg("q"),
          py::arg("l"), py::arg("u"), py::arg("n"), py::arg("x0"),
          "Dense ADMM solve. kwargs: alpha, rho, max_iter, primal_tol, dual_tol, "
          "rho_update_steps, rho_update_ratio, min_rho, max_rho, mode");

    // Sparse: Q and A are any objects exposing .tocoo() (SciPy CSR/CSC/COO);
    // they are converted via scipy_to_eigen_sparse before solving.
    m.def("solve_sparse_scipy",
          [](const py::object& Q_scipy,
             const py::object& A_scipy,
             const Eigen::VectorXd& q,
             const Eigen::VectorXd& l,
             const Eigen::VectorXd& u,
             int n,
             const Eigen::VectorXd& x0,
             const py::kwargs& kw) {
              auto opts = parse_options(kw);
              Eigen::SparseMatrix<double> Q = scipy_to_eigen_sparse(Q_scipy);
              Eigen::SparseMatrix<double> A = scipy_to_eigen_sparse(A_scipy);
              return admm::solve_sparse(Q, A, q, l, u, n, x0, opts);
          },
          py::arg("Q"), py::arg("A"), py::arg("q"),
          py::arg("l"), py::arg("u"), py::arg("n"), py::arg("x0"),
          // Same keyword options as solve_dense; list them so help() agrees.
          "Sparse ADMM solve from SciPy CSR/CSC/COO matrices. kwargs: alpha, rho, "
          "max_iter, primal_tol, dual_tol, rho_update_steps, rho_update_ratio, "
          "min_rho, max_rho, mode");
}