Skip to content

Commit 69f1ac8

Browse files
authored
SpMatrix: Add constructor for CSR format (#4316)
1 parent 12d0882 commit 69f1ac8

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

Src/LinearSolvers/AMReX_SpMatrix.H

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ public:
3434

3535
void define (AlgPartition partition, int nnz);
3636

37+
//! Define a matrix with CSR format data. Note that mat and col_index
38+
//! should contains nnz elements. The number of elements in row_index
39+
//! should the numbrer of local rows plus 1. The data can be freed after
40+
//! this function call. For GPU builds, the data are expected to be in
41+
//! GPU memory.
42+
void define (AlgPartition partition, T const* mat, Long const* col_index,
43+
Long nnz, Long const* row_index);
44+
3745
[[nodiscard]] AlgPartition const& partition () const { return m_partition; }
3846

3947
[[nodiscard]] Long numLocalRows () const { return m_row_end - m_row_begin; }
@@ -160,6 +168,28 @@ SpMatrix<T,Allocator>::define_doit (int nnz)
160168
});
161169
}
162170

171+
template <typename T, template<typename> class Allocator>
172+
void
173+
SpMatrix<T,Allocator>::define (AlgPartition partition, T const* mat,
174+
Long const* col_index, Long nnz,
175+
Long const* row_index)
176+
{
177+
m_partition = std::move(partition);
178+
m_row_begin = m_partition[ParallelDescriptor::MyProc()];
179+
m_row_end = m_partition[ParallelDescriptor::MyProc()+1];
180+
Long nlocalrows = this->numLocalRows();
181+
m_data.mat.resize(nnz);
182+
m_data.col_index.resize(nnz);
183+
m_data.row_offset.resize(nlocalrows+1);
184+
m_data.nnz = nnz;
185+
Gpu::copyAsync(Gpu::deviceToDevice, mat, mat+nnz, m_data.mat.begin());
186+
Gpu::copyAsync(Gpu::deviceToDevice, col_index, col_index+nnz,
187+
m_data.col_index.begin());
188+
Gpu::copyAsync(Gpu::deviceToDevice, row_index, row_index+nlocalrows+1,
189+
m_data.row_index.begin());
190+
Gpu::streamSynchronize();
191+
}
192+
163193
template <typename T, template<typename> class Allocator>
164194
void
165195
SpMatrix<T,Allocator>::printToFile (std::string const& file) const

0 commit comments

Comments
 (0)