Skip to content

Commit 3e4a4dd

Browse files
authored
Merge pull request #2453 from andreyfe1/conj_grad_fix
[oneMKL] Updated oneMKL API for Sparse Conjugate Gradient
2 parents 2ed1542 + d410c03 commit 3e4a4dd

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

Libraries/oneMKL/sparse_conjugate_gradient/sparse_cg.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ void run_sparse_cg_example(const sycl::device &dev)
138138
try {
139139
mkl::sparse::init_matrix_handle(&handle);
140140

141-
mkl::sparse::set_csr_data(handle, nrows, nrows, mkl::index_base::zero,
141+
mkl::sparse::set_csr_data(main_queue, handle, nrows, nrows, mkl::index_base::zero,
142142
ia_buffer, ja_buffer, a_buffer);
143143

144144
mkl::sparse::set_matrix_property(handle, mkl::sparse::property::symmetric);
@@ -174,13 +174,14 @@ void run_sparse_cg_example(const sycl::device &dev)
174174

175175
// Calculation B^{-1}r_0
176176
{
177+
fp alpha = 1.0;
177178
mkl::sparse::trsv(main_queue, mkl::uplo::lower,
178-
mkl::transpose::nontrans, mkl::diag::nonunit,
179-
handle, r_buffer, t_buffer);
179+
mkl::transpose::nontrans,
180+
mkl::diag::nonunit, alpha, handle, r_buffer, t_buffer);
180181
diagonal_mv<fp, intType>(main_queue, nrows, d_buffer, t_buffer);
181182
mkl::sparse::trsv(main_queue, mkl::uplo::upper,
182-
mkl::transpose::nontrans, mkl::diag::nonunit,
183-
handle, t_buffer, w_buffer);
183+
mkl::transpose::nontrans,
184+
mkl::diag::nonunit, alpha, handle, t_buffer, w_buffer);
184185
}
185186

186187
mkl::blas::copy(main_queue, nrows, w_buffer, 1, p_buffer, 1);
@@ -225,13 +226,14 @@ void run_sparse_cg_example(const sycl::device &dev)
225226

226227
// Calculate w_k = B^{-1}r_k
227228
{
229+
fp alpha = 1.0;
228230
mkl::sparse::trsv(main_queue, mkl::uplo::lower,
229231
mkl::transpose::nontrans,
230-
mkl::diag::nonunit, handle, r_buffer, t_buffer);
232+
mkl::diag::nonunit, alpha, handle, r_buffer, t_buffer);
231233
diagonal_mv<fp, intType>(main_queue, nrows, d_buffer, t_buffer);
232234
mkl::sparse::trsv(main_queue, mkl::uplo::upper,
233235
mkl::transpose::nontrans,
234-
mkl::diag::nonunit, handle, t_buffer, w_buffer);
236+
mkl::diag::nonunit, alpha, handle, t_buffer, w_buffer);
235237
}
236238

237239
// Calculate current norm of correction
@@ -271,8 +273,9 @@ void run_sparse_cg_example(const sycl::device &dev)
271273
catch (std::exception const &e) {
272274
std::cout << "\t\tCaught exception:\n" << e.what() << std::endl;
273275
}
274-
275-
mkl::sparse::release_matrix_handle(&handle);
276+
277+
mkl::sparse::release_matrix_handle(main_queue, &handle);
278+
main_queue.wait();
276279
}
277280

278281
//

0 commit comments

Comments
 (0)