Skip to content

Commit 7ed353f

Browse files
authored
Merge pull request #682 from NVIDIA/release/25.12
Forward-merge release/25.12 into main
2 parents 1b38dcd + 658daf1 commit 7ed353f

File tree

4 files changed

+278
-35
lines changed

4 files changed

+278
-35
lines changed

cpp/src/dual_simplex/barrier.cu

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,11 +1368,18 @@ class iteration_data_t {
13681368
}
13691369
}
13701370

1371+
template <typename T>
1372+
struct axpy_op {
1373+
T alpha;
1374+
T beta;
1375+
__host__ __device__ T operator()(T x, T y) const { return alpha * x + beta * y; }
1376+
};
1377+
13711378
// y <- alpha * Augmented * x + beta * y
13721379
void augmented_multiply(f_t alpha,
13731380
const dense_vector_t<i_t, f_t>& x,
13741381
f_t beta,
1375-
dense_vector_t<i_t, f_t>& y) const
1382+
dense_vector_t<i_t, f_t>& y)
13761383
{
13771384
const i_t m = A.m;
13781385
const i_t n = A.n;
@@ -1381,22 +1388,53 @@ class iteration_data_t {
13811388
dense_vector_t<i_t, f_t> y1 = y.head(n);
13821389
dense_vector_t<i_t, f_t> y2 = y.tail(m);
13831390

1391+
rmm::device_uvector<f_t> d_x1(n, handle_ptr->get_stream());
1392+
rmm::device_uvector<f_t> d_x2(m, handle_ptr->get_stream());
1393+
rmm::device_uvector<f_t> d_y1(n, handle_ptr->get_stream());
1394+
rmm::device_uvector<f_t> d_y2(m, handle_ptr->get_stream());
1395+
1396+
raft::copy(d_x1.data(), x1.data(), n, handle_ptr->get_stream());
1397+
raft::copy(d_x2.data(), x2.data(), m, handle_ptr->get_stream());
1398+
raft::copy(d_y1.data(), y1.data(), n, handle_ptr->get_stream());
1399+
raft::copy(d_y2.data(), y2.data(), m, handle_ptr->get_stream());
1400+
13841401
// y1 <- alpha ( -D * x_1 + A^T x_2) + beta * y1
1385-
dense_vector_t<i_t, f_t> r1(n);
1386-
diag.pairwise_product(x1, r1);
1387-
if (Q.n > 0) { matrix_vector_multiply(Q, 1.0, x1, 1.0, r1); }
1388-
y1.axpy(-alpha, r1, beta);
1389-
matrix_transpose_vector_multiply(A, alpha, x2, 1.0, y1);
13901402

1403+
rmm::device_uvector<f_t> d_r1(n, handle_ptr->get_stream());
1404+
1405+
// diag.pairwise_product(x1, r1);
1406+
// r1 <- D * x_1
1407+
thrust::transform(handle_ptr->get_thrust_policy(),
1408+
d_x1.data(),
1409+
d_x1.data() + n,
1410+
d_diag_.data(),
1411+
d_r1.data(),
1412+
thrust::multiplies<f_t>());
1413+
1414+
// r1 <- Q x1 + D x1
1415+
if (Q.n > 0) {
1416+
// matrix_vector_multiply(Q, 1.0, x1, 1.0, r1);
1417+
cusparse_Q_view_.spmv(1.0, d_x1, 1.0, d_r1);
1418+
}
1419+
1420+
// y1 <- - alpha * r1 + beta * y1
1421+
// y1.axpy(-alpha, r1, beta);
1422+
thrust::transform(handle_ptr->get_thrust_policy(),
1423+
d_r1.data(),
1424+
d_r1.data() + n,
1425+
d_y1.data(),
1426+
d_y1.data(),
1427+
axpy_op<f_t>{-alpha, beta});
1428+
1429+
// matrix_transpose_vector_multiply(A, alpha, x2, 1.0, y1);
1430+
cusparse_view_.transpose_spmv(alpha, d_x2, 1.0, d_y1);
13911431
// y2 <- alpha ( A*x) + beta * y2
1392-
matrix_vector_multiply(A, alpha, x1, beta, y2);
1432+
// matrix_vector_multiply(A, alpha, x1, beta, y2);
1433+
cusparse_view_.spmv(alpha, d_x1, beta, d_y2);
13931434

1394-
for (i_t i = 0; i < n; ++i) {
1395-
y[i] = y1[i];
1396-
}
1397-
for (i_t i = n; i < n + m; ++i) {
1398-
y[i] = y2[i - n];
1399-
}
1435+
raft::copy(y.data(), d_y1.data(), n, stream_view_);
1436+
raft::copy(y.data() + n, d_y2.data(), m, stream_view_);
1437+
handle_ptr->sync_stream();
14001438
}
14011439

14021440
raft::handle_t const* handle_ptr;
@@ -1711,8 +1749,8 @@ int barrier_solver_t<i_t, f_t>::initial_point(iteration_data_t<i_t, f_t>& data)
17111749
dense_vector_t<i_t, f_t> soln(lp.num_cols + lp.num_rows);
17121750
i_t solve_status = data.chol->solve(rhs, soln);
17131751
struct op_t {
1714-
op_t(const iteration_data_t<i_t, f_t>& data) : data_(data) {}
1715-
const iteration_data_t<i_t, f_t>& data_;
1752+
op_t(iteration_data_t<i_t, f_t>& data) : data_(data) {}
1753+
iteration_data_t<i_t, f_t>& data_;
17161754
void a_multiply(f_t alpha,
17171755
const dense_vector_t<i_t, f_t>& x,
17181756
f_t beta,
@@ -2410,12 +2448,12 @@ i_t barrier_solver_t<i_t, f_t>::gpu_compute_search_direction(iteration_data_t<i_
24102448
dense_vector_t<i_t, f_t> augmented_soln(lp.num_cols + lp.num_rows);
24112449
data.chol->solve(augmented_rhs, augmented_soln);
24122450
struct op_t {
2413-
op_t(const iteration_data_t<i_t, f_t>& data) : data_(data) {}
2414-
const iteration_data_t<i_t, f_t>& data_;
2451+
op_t(iteration_data_t<i_t, f_t>& data) : data_(data) {}
2452+
iteration_data_t<i_t, f_t>& data_;
24152453
void a_multiply(f_t alpha,
24162454
const dense_vector_t<i_t, f_t>& x,
24172455
f_t beta,
2418-
dense_vector_t<i_t, f_t>& y) const
2456+
dense_vector_t<i_t, f_t>& y)
24192457
{
24202458
data_.augmented_multiply(alpha, x, beta, y);
24212459
}

cpp/src/dual_simplex/cusparse_view.cu

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,21 @@ void cusparse_view_t<i_t, f_t>::spmv(f_t alpha,
268268
{
269269
auto d_x = device_copy(x, handle_ptr_->get_stream());
270270
auto d_y = device_copy(y, handle_ptr_->get_stream());
271-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(d_x);
272-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(d_y);
273-
spmv(alpha, x_cusparse, beta, y_cusparse);
271+
spmv(alpha, d_x, beta, d_y);
274272
y = cuopt::host_copy<f_t, AllocatorB>(d_y, handle_ptr_->get_stream());
275273
}
276274

275+
template <typename i_t, typename f_t>
276+
void cusparse_view_t<i_t, f_t>::spmv(f_t alpha,
277+
const rmm::device_uvector<f_t>& x,
278+
f_t beta,
279+
rmm::device_uvector<f_t>& y)
280+
{
281+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(x);
282+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(y);
283+
spmv(alpha, x_cusparse, beta, y_cusparse);
284+
}
285+
277286
template <typename i_t, typename f_t>
278287
void cusparse_view_t<i_t, f_t>::spmv(f_t alpha,
279288
detail::cusparse_dn_vec_descr_wrapper_t<f_t> const& x,
@@ -311,12 +320,21 @@ void cusparse_view_t<i_t, f_t>::transpose_spmv(f_t alpha,
311320
{
312321
auto d_x = device_copy(x, handle_ptr_->get_stream());
313322
auto d_y = device_copy(y, handle_ptr_->get_stream());
314-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(d_x);
315-
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(d_y);
316-
transpose_spmv(alpha, x_cusparse, beta, y_cusparse);
323+
transpose_spmv(alpha, d_x, beta, d_y);
317324
y = cuopt::host_copy<f_t, AllocatorB>(d_y, handle_ptr_->get_stream());
318325
}
319326

327+
template <typename i_t, typename f_t>
328+
void cusparse_view_t<i_t, f_t>::transpose_spmv(f_t alpha,
329+
const rmm::device_uvector<f_t>& x,
330+
f_t beta,
331+
rmm::device_uvector<f_t>& y)
332+
{
333+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> x_cusparse = create_vector(x);
334+
detail::cusparse_dn_vec_descr_wrapper_t<f_t> y_cusparse = create_vector(y);
335+
transpose_spmv(alpha, x_cusparse, beta, y_cusparse);
336+
}
337+
320338
template <typename i_t, typename f_t>
321339
void cusparse_view_t<i_t, f_t>::transpose_spmv(
322340
f_t alpha,

cpp/src/dual_simplex/cusparse_view.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class cusparse_view_t {
3636
const std::vector<f_t, AllocatorA>& x,
3737
f_t beta,
3838
std::vector<f_t, AllocatorB>& y);
39+
void spmv(f_t alpha, rmm::device_uvector<f_t> const& x, f_t beta, rmm::device_uvector<f_t>& y);
3940
void spmv(f_t alpha,
4041
detail::cusparse_dn_vec_descr_wrapper_t<f_t> const& x,
4142
f_t beta,
@@ -45,6 +46,10 @@ class cusparse_view_t {
4546
const std::vector<f_t, AllocatorA>& x,
4647
f_t beta,
4748
std::vector<f_t, AllocatorB>& y);
49+
void transpose_spmv(f_t alpha,
50+
rmm::device_uvector<f_t> const& x,
51+
f_t beta,
52+
rmm::device_uvector<f_t>& y);
4853
void transpose_spmv(f_t alpha,
4954
detail::cusparse_dn_vec_descr_wrapper_t<f_t> const& x,
5055
f_t beta,

0 commit comments

Comments
 (0)