Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion femutils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ option(ENABLE_DEBUG_MATRIX "Enable Debug matrix instead of a sparse one" OFF)
set(ACCELERATOR_SOURCES
ArcaneFemFunctionsGpu.cc
CsrDoFLinearSystemImpl.cc
CsrFormatMatrix.cc
BSRFormat.cc
)

Expand All @@ -27,7 +28,6 @@ add_library(FemUtils
DoFLinearSystem.cc
CooFormatMatrix.h
CsrFormatMatrix.h
CsrFormatMatrix.cc
CsrFormatMatrixView.h
CsrFormatMatrixView.cc
BSRFormat.h
Expand Down
44 changes: 33 additions & 11 deletions femutils/CsrFormatMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <arcane/core/IItemFamily.h>

#include <arcane/accelerator/core/RunQueue.h>
#include <arcane/accelerator/NumArrayViews.h>
#include <arcane/accelerator/RunCommandLoop.h>

#include "CsrFormatMatrix.h"
#include "DoFLinearSystem.h"
Expand Down Expand Up @@ -58,27 +60,45 @@ initialize(IItemFamily* dof_family, Int32 nnz, Int32 nbRow, RunQueue& queue)
void CsrFormat::
translateToLinearSystem(DoFLinearSystem& linear_system, const RunQueue& queue)
{
info() << "TranslateToLinearSystem this=" << this;
bool do_set_csr = linear_system.hasSetCSRValues();
info() << "TranslateToLinearSystem this=" << this << " is_csr=" << do_set_csr;

const Int32 nb_row = m_matrix_row.dim1Size();
const Int32 matrix_column_size = m_matrix_column.dim1Size();

// When using the CSR format, we need to know the number of non-zero values
// for each row.
// NOTE: it should be possible to compute that in setCoordinates(). This
// value is constant as long as the structure of the matrix does not change,
// so these values could be stored instead of being recomputed.
if (do_set_csr) {
m_matrix_rows_nb_column.resize(m_matrix_row.extent0());
//m_matrix_rows_nb_column.fill(0);
m_matrix_rows_nb_column.resize(nb_row);
// Compute, for each row, the number of stored entries from the CSR row
// offsets: row i spans [row(i), row(i + 1)), and the last row ends at
// matrix_column_size.
auto command = makeCommand(queue);
auto out_matrix_rows_nb_column = viewOut(command, m_matrix_rows_nb_column);
auto in_matrix_rows = viewIn(command, m_matrix_row);
command << RUNCOMMAND_LOOP1(iter, nb_row)
{
  auto [i] = iter();
  const Int32 row_begin = in_matrix_rows(i);
  const Int32 row_end = ((i + 1) < nb_row) ? in_matrix_rows(i + 1) : matrix_column_size;
  // Every row must be written explicitly, including empty ones: the array
  // has just been resized and is not assumed to be zero-filled. The clamp
  // keeps the result at 0 when row_end <= row_begin, which is what counting
  // the entries one by one would give.
  out_matrix_rows_nb_column[i] = (row_end > row_begin) ? (row_end - row_begin) : 0;
};
CSRFormatView csr_view(view());
linear_system.setCSRValues(csr_view);
return;
}
Int32 nb_row = m_matrix_row.dim1Size();

for (Int32 i = 0; i < nb_row; i++) {
m_matrix_rows_nb_column[i] = 0;
if (((i + 1) < nb_row) && (m_matrix_row(i) == m_matrix_row(i + 1)))
continue;
for (Int32 j = m_matrix_row(i); ((i + 1) < nb_row && j < m_matrix_row(i + 1)) || ((i + 1) == nb_row && j < m_matrix_column.dim1Size()); j++) {
if (do_set_csr) {
++m_matrix_rows_nb_column[i];
continue;
}
for (Int32 j = m_matrix_row(i); ((i + 1) < nb_row && j < m_matrix_row(i + 1)) || ((i + 1) == nb_row && j < matrix_column_size); j++) {
if (DoFLocalId(m_matrix_column(j)).isNull())
continue;
//info() << "Add: (" << i << ", " << m_matrix_column(j) << " v=" << m_matrix_value(j);
Expand All @@ -87,10 +107,12 @@ translateToLinearSystem(DoFLinearSystem& linear_system, const RunQueue& queue)
}

if (do_set_csr) {
CSRFormatView csr_view(view());
linear_system.setCSRValues(csr_view);
}
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

CsrFormatMatrixView CsrFormat::
view()
{
Expand Down