Skip to content

Commit 658daf1

Browse files
authored
Improve iterative refinement (#677)
This PR adds GMRES based iterative refinement. The cholesky (or LDL) factorization is used as the preconditioning strategy for the GMRES iterations. For now, this is enabled only for QP problems. Additional: * GPU-accelerated augmented multiply * Native device-vector support for sparse matrix multiply Authors: - Rajesh Gandham (https://github.com/rg20) Approvers: - Chris Maes (https://github.com/chris-maes) URL: #677
1 parent a5b226c commit 658daf1

File tree

4 files changed

+278
-35
lines changed

4 files changed

+278
-35
lines changed

cpp/src/dual_simplex/barrier.cu

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,11 +1368,18 @@ class iteration_data_t {
13681368
}
13691369
}
13701370

1371+
template <typename T>
1372+
struct axpy_op {
1373+
T alpha;
1374+
T beta;
1375+
__host__ __device__ T operator()(T x, T y) const { return alpha * x + beta * y; }
1376+
};
1377+
13711378
// y <- alpha * Augmented * x + beta * y
13721379
void augmented_multiply(f_t alpha,
13731380
const dense_vector_t<i_t, f_t>& x,
13741381
f_t beta,
1375-
dense_vector_t<i_t, f_t>& y) const
1382+
dense_vector_t<i_t, f_t>& y)
13761383
{
13771384
const i_t m = A.m;
13781385
const i_t n = A.n;
@@ -1381,22 +1388,53 @@ class iteration_data_t {
13811388
dense_vector_t<i_t, f_t> y1 = y.head(n);
13821389
dense_vector_t<i_t, f_t> y2 = y.tail(m);
13831390

1391+
rmm::device_uvector<f_t> d_x1(n, handle_ptr->get_stream());
1392+
rmm::device_uvector<f_t> d_x2(m, handle_ptr->get_stream());
1393+
rmm::device_uvector<f_t> d_y1(n, handle_ptr->get_stream());
1394+
rmm::device_uvector<f_t> d_y2(m, handle_ptr->get_stream());
1395+
1396+
raft::copy(d_x1.data(), x1.data(), n, handle_ptr->get_stream());
1397+
raft::copy(d_x2.data(), x2.data(), m, handle_ptr->get_stream());
1398+
raft::copy(d_y1.data(), y1.data(), n, handle_ptr->get_stream());
1399+
raft::copy(d_y2.data(), y2.data(), m, handle_ptr->get_stream());
1400+
13841401
// y1 <- alpha ( -D * x_1 + A^T x_2) + beta * y1
1385-
dense_vector_t<i_t, f_t> r1(n);
1386-
diag.pairwise_product(x1, r1);
1387-
if (Q.n > 0) { matrix_vector_multiply(Q, 1.0, x1, 1.0, r1); }
1388-
y1.axpy(-alpha, r1, beta);
1389-
matrix_transpose_vector_multiply(A, alpha, x2, 1.0, y1);
13901402

1403+
rmm::device_uvector<f_t> d_r1(n, handle_ptr->get_stream());
1404+
1405+
// diag.pairwise_product(x1, r1);
1406+
// r1 <- D * x_1
1407+
thrust::transform(handle_ptr->get_thrust_policy(),
1408+
d_x1.data(),
1409+
d_x1.data() + n,
1410+
d_diag_.data(),
1411+
d_r1.data(),
1412+
thrust::multiplies<f_t>());
1413+
1414+
// r1 <- Q x1 + D x1
1415+
if (Q.n > 0) {
1416+
// matrix_vector_multiply(Q, 1.0, x1, 1.0, r1);
1417+
cusparse_Q_view_.spmv(1.0, d_x1, 1.0, d_r1);
1418+
}
1419+
1420+
// y1 <- - alpha * r1 + beta * y1
1421+
// y1.axpy(-alpha, r1, beta);
1422+
thrust::transform(handle_ptr->get_thrust_policy(),
1423+
d_r1.data(),
1424+
d_r1.data() + n,
1425+
d_y1.data(),
1426+
d_y1.data(),
1427+
axpy_op<f_t>{-alpha, beta});
1428+
1429+
// matrix_transpose_vector_multiply(A, alpha, x2, 1.0, y1);
1430+
cusparse_view_.transpose_spmv(alpha, d_x2, 1.0, d_y1);
13911431
// y2 <- alpha ( A*x) + beta * y2
1392-
matrix_vector_multiply(A, alpha, x1, beta, y2);
1432+
// matrix_vector_multiply(A, alpha, x1, beta, y2);
1433+
cusparse_view_.spmv(alpha, d_x1, beta, d_y2);
13931434

1394-
for (i_t i = 0; i < n; ++i) {
1395-
y[i] = y1[i];
1396-
}
1397-
for (i_t i = n; i < n + m; ++i) {
1398-
y[i] = y2[i - n];
1399-
}
1435+
raft::copy(y.data(), d_y1.data(), n, stream_view_);
1436+
raft::copy(y.data() + n, d_y2.data(), m, stream_view_);
1437+
handle_ptr->sync_stream();
14001438
}
14011439

14021440
raft::handle_t const* handle_ptr;
@@ -1711,8 +1749,8 @@ int barrier_solver_t<i_t, f_t>::initial_point(iteration_data_t<i_t, f_t>& data)
17111749
dense_vector_t<i_t, f_t> soln(lp.num_cols + lp.num_rows);
17121750
i_t solve_status = data.chol->solve(rhs, soln);
17131751
struct op_t {
1714-
op_t(const iteration_data_t<i_t, f_t>& data) : data_(data) {}
1715-
const iteration_data_t<i_t, f_t>& data_;
1752+
op_t(iteration_data_t<i_t, f_t>& data) : data_(data) {}
1753+
iteration_data_t<i_t, f_t>& data_;
17161754
void a_multiply(f_t alpha,
17171755
const dense_vector_t<i_t, f_t>& x,
17181756
f_t beta,
@@ -2410,12 +2448,12 @@ i_t barrier_solver_t<i_t, f_t>::gpu_compute_search_direction(iteration_data_t<i_
24102448
dense_vector_t<i_t, f_t> augmented_soln(lp.num_cols + lp.num_rows);
24112449
data.chol->solve(augmented_rhs, augmented_soln);
24122450
struct op_t {
2413-
op_t(const iteration_data_t<i_t, f_t>& data) : data_(data) {}
2414-
const iteration_data_t<i_t, f_t>& data_;
2451+
op_t(iteration_data_t<i_t, f_t>& data) : data_(data) {}
2452+
iteration_data_t<i_t, f_t>& data_;
24152453
void a_multiply(f_t alpha,
24162454
const dense_vector_t<i_t, f_t>& x,
24172455
f_t beta,
2418-
dense_vector_t<i_t, f_t>& y) const
2456+
dense_vector_t<i_t, f_t>& y)
24192457
{
24202458
data_.augmented_multiply(alpha, x, beta, y);
24212459
}

cpp/src/dual_simplex/cusparse_view.cu

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,21 @@ void cusparse_view_t<i_t, f_t>::spmv(f_t alpha,
268268
{
269269
auto d_x = device_copy(x, handle_ptr_->get_stream());
270270
auto d_y = device_copy(y, handle_ptr_->get_stream());
271-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(d_x);
272-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(d_y);
273-
spmv(alpha, x_cusparse, beta, y_cusparse);
271+
spmv(alpha, d_x, beta, d_y);
274272
y = cuopt::host_copy<f_t, AllocatorB>(d_y, handle_ptr_->get_stream());
275273
}
276274

275+
template <typename i_t, typename f_t>
276+
void cusparse_view_t<i_t, f_t>::spmv(f_t alpha,
277+
const rmm::device_uvector<f_t>& x,
278+
f_t beta,
279+
rmm::device_uvector<f_t>& y)
280+
{
281+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(x);
282+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(y);
283+
spmv(alpha, x_cusparse, beta, y_cusparse);
284+
}
285+
277286
template <typename i_t, typename f_t>
278287
void cusparse_view_t<i_t, f_t>::spmv(f_t alpha,
279288
detail::cusparse_dn_vec_descr_wrapper_t<f_t> const& x,
@@ -311,12 +320,21 @@ void cusparse_view_t<i_t, f_t>::transpose_spmv(f_t alpha,
311320
{
312321
auto d_x = device_copy(x, handle_ptr_->get_stream());
313322
auto d_y = device_copy(y, handle_ptr_->get_stream());
314-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(d_x);
315-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(d_y);
316-
transpose_spmv(alpha, x_cusparse, beta, y_cusparse);
323+
transpose_spmv(alpha, d_x, beta, d_y);
317324
y = cuopt::host_copy<f_t, AllocatorB>(d_y, handle_ptr_->get_stream());
318325
}
319326

327+
template <typename i_t, typename f_t>
328+
void cusparse_view_t<i_t, f_t>::transpose_spmv(f_t alpha,
329+
const rmm::device_uvector<f_t>& x,
330+
f_t beta,
331+
rmm::device_uvector<f_t>& y)
332+
{
333+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(x);
334+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(y);
335+
transpose_spmv(alpha, x_cusparse, beta, y_cusparse);
336+
}
337+
320338
template <typename i_t, typename f_t>
321339
void cusparse_view_t<i_t, f_t>::transpose_spmv(
322340
f_t alpha,

cpp/src/dual_simplex/cusparse_view.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class cusparse_view_t {
3636
const std::vector<f_t, AllocatorA>& x,
3737
f_t beta,
3838
std::vector<f_t, AllocatorB>& y);
39+
void spmv(f_t alpha, rmm::device_uvector<f_t> const& x, f_t beta, rmm::device_uvector<f_t>& y);
3940
void spmv(f_t alpha,
4041
detail::cusparse_dn_vec_descr_wrapper_t<f_t> const& x,
4142
f_t beta,
@@ -45,6 +46,10 @@ class cusparse_view_t {
4546
const std::vector<f_t, AllocatorA>& x,
4647
f_t beta,
4748
std::vector<f_t, AllocatorB>& y);
49+
void transpose_spmv(f_t alpha,
50+
rmm::device_uvector<f_t> const& x,
51+
f_t beta,
52+
rmm::device_uvector<f_t>& y);
4853
void transpose_spmv(f_t alpha,
4954
detail::cusparse_dn_vec_descr_wrapper_t<f_t> const& x,
5055
f_t beta,

0 commit comments

Comments
 (0)