Skip to content

Commit 6b2d2fe

Browse files
Merge pull request arcaneframework#313 from arcaneframework/dev/gg-separate-matrix-and-rhs-transformation
Separate Matrix and RHS transformation for `AlephDoFLinearSystemImpl` et `HypreDoFLinearSystemImpl`
2 parents aa97496 + 02c7bf0 commit 6b2d2fe

File tree

4 files changed

+97
-33
lines changed

4 files changed

+97
-33
lines changed

femutils/AlephDoFLinearSystem.cc

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,9 @@ class AlephDoFLinearSystemImpl
199199
private:
200200

201201
AlephParams* _createAlephParam() const;
202-
void _fillMatrix();
202+
void _applyMatrixTransformationAndFillAlephMatrix();
203203
void _fillRHSVector();
204+
void _applyRHSTransformationAndFillAlephRHS();
204205
void _setMatrixValue(DoF row, DoF column, Real value)
205206
{
206207
if (m_do_print_filling)
@@ -265,7 +266,6 @@ _fillRowColumnEliminationInfos()
265266

266267
auto& dof_elimination_info = getEliminationInfo();
267268
auto& dof_elimination_value = getEliminationValue();
268-
auto& rhs_variable = rhsVariable();
269269

270270
for (const auto& rc_value : m_values_map) {
271271
RowColumn rc = rc_value.first;
@@ -283,7 +283,7 @@ _fillRowColumnEliminationInfos()
283283
/*---------------------------------------------------------------------------*/
284284

285285
void AlephDoFLinearSystemImpl::
286-
_fillMatrix()
286+
_applyMatrixTransformationAndFillAlephMatrix()
287287
{
288288
_fillRowColumnEliminationInfos();
289289
OrderedRowColumnMap& rc_elimination_map = _rowColumnEliminationMap();
@@ -293,7 +293,6 @@ _fillMatrix()
293293

294294
auto& dof_elimination_info = getEliminationInfo();
295295
auto& dof_elimination_value = getEliminationValue();
296-
auto& rhs_variable = rhsVariable();
297296

298297
// Fill the matrix from the values of \a m_values_map
299298
// Skip (row,column) values which are part of an elimination.
@@ -322,14 +321,44 @@ _fillMatrix()
322321
_setMatrixValue(dof_row, dof_column, value);
323322
}
324323

324+
const bool do_print_filling = m_do_print_filling;
325+
326+
// Apply Row or Row+Column elimination on Matrix
327+
// Phase 2: set the diagonal value for elimination row to 1.0
328+
ENUMERATE_ (DoF, idof, dof_family->allItems()) {
329+
DoF dof = *idof;
330+
if (!dof.isOwn())
331+
continue;
332+
Byte elimination_info = dof_elimination_info[dof];
333+
if (elimination_info == ELIMINATE_ROW || elimination_info == ELIMINATE_ROW_COLUMN) {
334+
Real elimination_value = dof_elimination_value[dof];
335+
if (do_print_filling)
336+
info() << "EliminateMatrix info=" << static_cast<int>(elimination_info) << " row="
337+
<< std::setw(4) << dof.localId() << " value=" << elimination_value;
338+
_setMatrixValue(dof, dof, 1.0);
339+
}
340+
}
341+
}
342+
/*---------------------------------------------------------------------------*/
343+
/*---------------------------------------------------------------------------*/
344+
345+
void AlephDoFLinearSystemImpl::
346+
_applyRHSTransformationAndFillAlephRHS()
347+
{
348+
const bool do_print_filling = m_do_print_filling;
349+
325350
// Apply Row+Column elimination
326351
// Phase 1:
327352
// - subtract values of the RHS vector if Row+Column elimination
328-
_applyRowColumnEliminationToRHS(m_do_print_filling);
353+
_applyRowColumnEliminationToRHS(do_print_filling);
354+
355+
IItemFamily* dof_family = dofFamily();
356+
357+
auto& dof_elimination_info = getEliminationInfo();
358+
auto& dof_elimination_value = getEliminationValue();
359+
auto& rhs_variable = rhsVariable();
329360

330-
// Apply Row or Row+Column elimination
331-
// Phase 2: set the value of the RHS
332-
// Phase 2: fill the diagonal with 1.0
361+
// Apply Row or Row+Column elimination on RHS
333362
ENUMERATE_ (DoF, idof, dof_family->allItems()) {
334363
DoF dof = *idof;
335364
if (!dof.isOwn())
@@ -338,11 +367,13 @@ _fillMatrix()
338367
if (elimination_info == ELIMINATE_ROW || elimination_info == ELIMINATE_ROW_COLUMN) {
339368
Real elimination_value = dof_elimination_value[dof];
340369
rhs_variable[dof] = elimination_value;
341-
info() << "Eliminate info=" << static_cast<int>(elimination_info) << " row="
342-
<< std::setw(4) << dof.localId() << " value=" << elimination_value;
343-
_setMatrixValue(dof, dof, 1.0);
370+
if (do_print_filling)
371+
info() << "EliminateRHS info=" << static_cast<int>(elimination_info) << " row="
372+
<< std::setw(4) << dof.localId() << " value=" << elimination_value;
344373
}
345374
}
375+
376+
_fillRHSVector();
346377
}
347378

348379
/*---------------------------------------------------------------------------*/
@@ -468,15 +499,16 @@ solve()
468499
{
469500
UniqueArray<Real> aleph_result;
470501

471-
// _fillMatrix() may change the values of RHS vector
472-
// with row or row-column elimination so we have to fill the RHS vector
473-
// before the matrix.
474-
_fillMatrix();
475-
_fillRHSVector();
476-
477502
info() << "[AlephFem] Assemble matrix ptr=" << m_aleph_matrix;
503+
504+
// Matrix transformation
505+
_applyMatrixTransformationAndFillAlephMatrix();
478506
m_aleph_matrix->assemble();
507+
508+
// RHS Transformation
509+
_applyRHSTransformationAndFillAlephRHS();
479510
m_aleph_rhs_vector->assemble();
511+
480512
auto* aleph_solution_vector = m_aleph_solution_vector;
481513
IItemFamily* dof_family = dofFamily();
482514
DoFGroup own_dofs = dof_family->allItems().own();

femutils/CsrDoFLinearSystemImpl.cc

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,18 @@ _applyForcedValuesToLhs()
7676
/*---------------------------------------------------------------------------*/
7777

7878
void CsrDoFLinearSystemImpl::
79-
_applyRowOrRowColumnElimination()
79+
_applyRowOrRowColumnEliminationOnMatrix()
8080
{
81-
_applyRowElimination();
81+
_applyRowEliminationOnMatrix();
8282
if (m_has_row_column_elimination)
83-
_applyRowColumnElimination();
83+
_applyRowColumnEliminationOnMatrix();
8484
}
8585

8686
/*---------------------------------------------------------------------------*/
8787
/*---------------------------------------------------------------------------*/
8888

8989
void CsrDoFLinearSystemImpl::
90-
_applyRowColumnElimination()
90+
_applyRowColumnEliminationOnMatrix()
9191
{
9292
IItemFamily* dof_family = dofFamily();
9393

@@ -118,18 +118,14 @@ _applyRowColumnElimination()
118118
}
119119
}
120120
}
121-
if (is_row_elimination) {
122-
auto elimination_value = in_elimination_value[dof_row];
123-
in_out_rhs_variable[dof_row] = elimination_value;
124-
}
125121
};
126122
}
127123

128124
/*---------------------------------------------------------------------------*/
129125
/*---------------------------------------------------------------------------*/
130126

131127
void CsrDoFLinearSystemImpl::
132-
_applyRowElimination()
128+
_applyRowEliminationOnMatrix()
133129
{
134130
IItemFamily* dof_family = dofFamily();
135131

@@ -152,7 +148,37 @@ _applyRowElimination()
152148
auto elimination_value = in_elimination_value[dof_id];
153149
for (CsrRowColumnIndex csr_index : csr_view.rowRange(dof_id))
154150
csr_view.value(csr_index) = (csr_view.column(csr_index) == dof_id) ? 1.0 : 0.0;
155-
in_out_rhs_variable[dof_id] = elimination_value;
151+
}
152+
};
153+
}
154+
155+
/*---------------------------------------------------------------------------*/
156+
/*---------------------------------------------------------------------------*/
157+
158+
void CsrDoFLinearSystemImpl::
159+
_applyRowOrRowColumnEliminationOnRHS()
160+
{
161+
IItemFamily* dof_family = dofFamily();
162+
163+
auto nb_dof = dof_family->nbItem();
164+
165+
RunQueue queue = makeQueue(runner());
166+
auto command = makeCommand(queue);
167+
168+
auto in_elimination_info = Accelerator::viewIn(command, getEliminationInfo());
169+
auto in_elimination_value = Accelerator::viewIn(command, getEliminationValue());
170+
171+
auto in_out_rhs_variable = Accelerator::viewInOut(command, rhsVariable());
172+
auto csr_view = m_csr_view;
173+
command << RUNCOMMAND_LOOP1(iter, nb_dof)
174+
{
175+
auto [row_index] = iter();
176+
DoFLocalId dof_row(row_index);
177+
auto row_elimination_info = in_elimination_info[dof_row];
178+
bool is_row_elimination = (row_elimination_info == ELIMINATE_ROW) || (row_elimination_info == ELIMINATE_ROW_COLUMN);
179+
if (is_row_elimination) {
180+
auto elimination_value = in_elimination_value[dof_row];
181+
in_out_rhs_variable[dof_row] = elimination_value;
156182
}
157183
};
158184
}

femutils/HypreDoFLinearSystem.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,17 @@ namespace
258258
void HypreDoFLinearSystemImpl::
259259
solve()
260260
{
261+
const bool do_debug_print = false;
262+
263+
// Matrix transformation
261264
_fillRowColumnEliminationInfos();
262-
_applyRowColumnEliminationToRHS(true);
263-
_applyRowOrRowColumnElimination();
265+
_applyRowOrRowColumnEliminationOnMatrix();
264266
_applyForcedValuesToLhs();
265267

268+
// RHS transformation
269+
_applyRowColumnEliminationToRHS(do_debug_print);
270+
_applyRowOrRowColumnEliminationOnRHS();
271+
266272
#if HYPRE_RELEASE_NUMBER >= 22700
267273
HYPRE_MemoryLocation hypre_memory = HYPRE_MEMORY_HOST;
268274
HYPRE_ExecutionPolicy hypre_exec_policy = HYPRE_EXEC_HOST;
@@ -343,7 +349,6 @@ solve()
343349
HYPRE_IJMatrix ij_A = nullptr;
344350
HYPRE_ParCSRMatrix parcsr_A = nullptr;
345351

346-
const bool do_debug_print = false;
347352
const bool do_dump_matrix = false;
348353

349354
Span<const Int32> rows_index_span = m_dof_matrix_numbering.asArray();

femutils/internal/CsrDoFLinearSystemImpl.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,13 @@ class CsrDoFLinearSystemImpl
100100
// These methods should be private but has to be public because of NVidia compiler
101101
void _applyForcedValuesToLhs();
102102
void _fillRowColumnEliminationInfos();
103-
void _applyRowElimination();
104-
void _applyRowColumnElimination();
103+
void _applyRowEliminationOnMatrix();
104+
void _applyRowColumnEliminationOnMatrix();
105+
void _applyRowOrRowColumnEliminationOnRHS();
105106

106107
protected:
107108

108-
void _applyRowOrRowColumnElimination();
109+
void _applyRowOrRowColumnEliminationOnMatrix();
109110
};
110111

111112
/*---------------------------------------------------------------------------*/

0 commit comments

Comments
 (0)