@@ -1368,11 +1368,18 @@ class iteration_data_t {
13681368 }
13691369 }
13701370
1371+ template <typename T>
1372+ struct axpy_op {
1373+ T alpha;
1374+ T beta;
1375+ __host__ __device__ T operator ()(T x, T y) const { return alpha * x + beta * y; }
1376+ };
1377+
13711378 // y <- alpha * Augmented * x + beta * y
13721379 void augmented_multiply (f_t alpha,
13731380 const dense_vector_t <i_t , f_t >& x,
13741381 f_t beta,
1375- dense_vector_t <i_t , f_t >& y) const
1382+ dense_vector_t <i_t , f_t >& y)
13761383 {
13771384 const i_t m = A.m ;
13781385 const i_t n = A.n ;
@@ -1381,22 +1388,53 @@ class iteration_data_t {
13811388 dense_vector_t <i_t , f_t > y1 = y.head (n);
13821389 dense_vector_t <i_t , f_t > y2 = y.tail (m);
13831390
1391+ rmm::device_uvector<f_t > d_x1 (n, handle_ptr->get_stream ());
1392+ rmm::device_uvector<f_t > d_x2 (m, handle_ptr->get_stream ());
1393+ rmm::device_uvector<f_t > d_y1 (n, handle_ptr->get_stream ());
1394+ rmm::device_uvector<f_t > d_y2 (m, handle_ptr->get_stream ());
1395+
1396+ raft::copy (d_x1.data (), x1.data (), n, handle_ptr->get_stream ());
1397+ raft::copy (d_x2.data (), x2.data (), m, handle_ptr->get_stream ());
1398+ raft::copy (d_y1.data (), y1.data (), n, handle_ptr->get_stream ());
1399+ raft::copy (d_y2.data (), y2.data (), m, handle_ptr->get_stream ());
1400+
13841401 // y1 <- alpha ( -D * x_1 + A^T x_2) + beta * y1
1385- dense_vector_t <i_t , f_t > r1 (n);
1386- diag.pairwise_product (x1, r1);
1387- if (Q.n > 0 ) { matrix_vector_multiply (Q, 1.0 , x1, 1.0 , r1); }
1388- y1.axpy (-alpha, r1, beta);
1389- matrix_transpose_vector_multiply (A, alpha, x2, 1.0 , y1);
13901402
1403+ rmm::device_uvector<f_t > d_r1 (n, handle_ptr->get_stream ());
1404+
1405+ // diag.pairwise_product(x1, r1);
1406+ // r1 <- D * x_1
1407+ thrust::transform (handle_ptr->get_thrust_policy (),
1408+ d_x1.data (),
1409+ d_x1.data () + n,
1410+ d_diag_.data (),
1411+ d_r1.data (),
1412+ thrust::multiplies<f_t >());
1413+
1414+ // r1 <- Q x1 + D x1
1415+ if (Q.n > 0 ) {
1416+ // matrix_vector_multiply(Q, 1.0, x1, 1.0, r1);
1417+ cusparse_Q_view_.spmv (1.0 , d_x1, 1.0 , d_r1);
1418+ }
1419+
1420+ // y1 <- - alpha * r1 + beta * y1
1421+ // y1.axpy(-alpha, r1, beta);
1422+ thrust::transform (handle_ptr->get_thrust_policy (),
1423+ d_r1.data (),
1424+ d_r1.data () + n,
1425+ d_y1.data (),
1426+ d_y1.data (),
1427+ axpy_op<f_t >{-alpha, beta});
1428+
1429+ // matrix_transpose_vector_multiply(A, alpha, x2, 1.0, y1);
1430+ cusparse_view_.transpose_spmv (alpha, d_x2, 1.0 , d_y1);
13911431 // y2 <- alpha ( A*x) + beta * y2
1392- matrix_vector_multiply (A, alpha, x1, beta, y2);
1432+ // matrix_vector_multiply(A, alpha, x1, beta, y2);
1433+ cusparse_view_.spmv (alpha, d_x1, beta, d_y2);
13931434
1394- for (i_t i = 0 ; i < n; ++i) {
1395- y[i] = y1[i];
1396- }
1397- for (i_t i = n; i < n + m; ++i) {
1398- y[i] = y2[i - n];
1399- }
1435+ raft::copy (y.data (), d_y1.data (), n, stream_view_);
1436+ raft::copy (y.data () + n, d_y2.data (), m, stream_view_);
1437+ handle_ptr->sync_stream ();
14001438 }
14011439
14021440 raft::handle_t const * handle_ptr;
@@ -1711,8 +1749,8 @@ int barrier_solver_t<i_t, f_t>::initial_point(iteration_data_t<i_t, f_t>& data)
17111749 dense_vector_t <i_t , f_t > soln (lp.num_cols + lp.num_rows );
17121750 i_t solve_status = data.chol ->solve (rhs, soln);
17131751 struct op_t {
1714- op_t (const iteration_data_t <i_t , f_t >& data) : data_(data) {}
1715- const iteration_data_t <i_t , f_t >& data_;
1752+ op_t (iteration_data_t <i_t , f_t >& data) : data_(data) {}
1753+ iteration_data_t <i_t , f_t >& data_;
17161754 void a_multiply (f_t alpha,
17171755 const dense_vector_t <i_t , f_t >& x,
17181756 f_t beta,
@@ -2410,12 +2448,12 @@ i_t barrier_solver_t<i_t, f_t>::gpu_compute_search_direction(iteration_data_t<i_
24102448 dense_vector_t <i_t , f_t > augmented_soln (lp.num_cols + lp.num_rows );
24112449 data.chol ->solve (augmented_rhs, augmented_soln);
24122450 struct op_t {
2413- op_t (const iteration_data_t <i_t , f_t >& data) : data_(data) {}
2414- const iteration_data_t <i_t , f_t >& data_;
2451+ op_t (iteration_data_t <i_t , f_t >& data) : data_(data) {}
2452+ iteration_data_t <i_t , f_t >& data_;
24152453 void a_multiply (f_t alpha,
24162454 const dense_vector_t <i_t , f_t >& x,
24172455 f_t beta,
2418- dense_vector_t <i_t , f_t >& y) const
2456+ dense_vector_t <i_t , f_t >& y)
24192457 {
24202458 data_.augmented_multiply (alpha, x, beta, y);
24212459 }
0 commit comments