Skip to content

Commit 636129e

Browse files
Merge pull request #14194 from KratosMultiphysics/core/csr_utilities
[FastPR][Core] Enabling getting indices from connectivities array
2 parents 386eede + 90f4fcc commit 636129e

File tree

2 files changed

+88
-8
lines changed

2 files changed

+88
-8
lines changed

kratos/python/add_sparse_matrices_to_python.cpp

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,62 @@ namespace
4444

4545
using EquationIdVectorType = std::vector<std::size_t>;
4646

47+
/**
48+
* @brief Get the Equation Id Csr Indices object
49+
* @details This function computes the CSR matrix value vector indices corresponding to the equation ids of the entities in the provided container.
50+
* @tparam TContainerType Type of the container of the entities.
51+
* @param rCsrMatrix The CSR matrix to get the indices from.
52+
* @param rConnectivities The connectivities of the entities.
53+
* @return NDData The NDData object containing the indices.
54+
*/
55+
NDData<int> GetEquationIdCsrIndices(
56+
const CsrMatrixType& rCsrMatrix,
57+
const NDData<int>& rConnectivities)
58+
{
59+
// Get shapes from the input connectivities
60+
const auto& r_shape = rConnectivities.Shape();
61+
KRATOS_ERROR_IF(r_shape.size() != 2) << "Input connectivities must have shape (n_entities, local_size)" << std::endl;
62+
const std::size_t n_entities = r_shape[0];
63+
const std::size_t local_size = r_shape[1];
64+
65+
// Assign the input NDData to have shape: number of entities * local_size * local_size
66+
DenseVector<unsigned int> nd_data_shape(3);
67+
nd_data_shape[0] = n_entities;
68+
nd_data_shape[1] = local_size;
69+
nd_data_shape[2] = local_size;
70+
auto eq_ids_data = NDData<int>(nd_data_shape);
71+
72+
// Loop over the connectivities
73+
auto eq_ids_data_view = eq_ids_data.ViewData();
74+
auto connectivities_view = rConnectivities.ViewData();
75+
IndexPartition<std::size_t>(n_entities).for_each([&](std::size_t i) {
76+
// Get current entity position in the connectivities view
77+
const std::size_t idx_start = i * local_size;
78+
const std::size_t eq_ids_pos = i * (local_size * local_size);
79+
80+
// Loop over the DOFs
81+
for (unsigned int i_local = 0; i_local < local_size; ++i_local) {
82+
const unsigned int i_global = connectivities_view[idx_start + i_local]; // Row global equation id
83+
for(unsigned int j_local = 0; j_local < local_size; ++j_local) {
84+
const unsigned int j_global = connectivities_view[idx_start + j_local]; // Column global equation id
85+
const unsigned int csr_index = rCsrMatrix.FindValueIndex(i_global,j_global); // Index in the CSR matrix values vector
86+
eq_ids_data_view[eq_ids_pos + i_local * local_size + j_local] = csr_index;
87+
}
88+
}
89+
});
90+
91+
// Return the data container
92+
return eq_ids_data;
93+
}
94+
4795
/**
4896
* @brief Get the Equation Id Csr Indices object
4997
* @details This function computes the CSR matrix value vector indices corresponding to the equation ids of the entities in the provided container.
5098
* @tparam TContainerType Type of the container of the entities.
5199
* @param rCsrMatrix The CSR matrix to get the indices from.
52100
* @param rContainer The container of the entities.
53101
* @param rProcessInfo The process info.
54-
* @param pNDData Pointer to the NDData object to store the indices.
102+
* @return NDData The NDData object containing the indices.
55103
*/
56104
template<class TContainerType>
57105
NDData<int> GetEquationIdCsrIndices(
@@ -73,21 +121,21 @@ namespace
73121
auto eq_ids_data = NDData<int>(nd_data_shape);
74122

75123
// Loop over the container
124+
EquationIdVectorType aux_tls;
76125
auto eq_ids_data_view = eq_ids_data.ViewData();
77-
IndexPartition<std::size_t>(n_entities).for_each([&](std::size_t i) {
126+
IndexPartition<std::size_t>(n_entities).for_each(aux_tls, [&](std::size_t i, EquationIdVectorType& rTLS) {
78127
// Get current entity
79128
auto it = rContainer.begin() + i;
80129
const std::size_t it_pos = i * (local_size * local_size);
81130

82131
// Get current entity equation ids
83-
EquationIdVectorType equation_ids;
84-
it->EquationIdVector(equation_ids, rProcessInfo);
132+
it->EquationIdVector(rTLS, rProcessInfo);
85133

86134
// Loop over the DOFs
87135
for (unsigned int i_local = 0; i_local < local_size; ++i_local) {
88-
const unsigned int i_global = equation_ids[i_local]; // Row global equation id
136+
const unsigned int i_global = rTLS[i_local]; // Row global equation id
89137
for(unsigned int j_local = 0; j_local < local_size; ++j_local) {
90-
const unsigned int j_global = equation_ids[j_local]; // Column global equation id
138+
const unsigned int j_global = rTLS[j_local]; // Column global equation id
91139
const unsigned int csr_index = rCsrMatrix.FindValueIndex(i_global,j_global); // Index in the CSR matrix values vector
92140
eq_ids_data_view[it_pos + i_local * local_size + j_local] = csr_index;
93141
}
@@ -132,7 +180,7 @@ namespace
132180
const std::size_t aux_idx = entity_pos + i_local * local_size_1 + j_local; // Position in the contributions and equation ids data
133181
const int csr_index = r_idx_data[aux_idx]; // Index in the CSR matrix values vector
134182
const double lhs_contribution = r_lhs_contribution_data[aux_idx]; // Scalar contribution to the left hand side
135-
r_lhs_data[csr_index] += lhs_contribution;
183+
AtomicAdd(r_lhs_data[csr_index], lhs_contribution);
136184
}
137185
}
138186
});
@@ -306,6 +354,9 @@ void AddSparseMatricesToPython(pybind11::module& m)
306354
.def("__matmul__", [](CsrMatrix<double,IndexType>& rA,CsrMatrix<double,IndexType>& rB){
307355
return AmgclCSRSpMMUtilities::SparseMultiply(rA,rB);
308356
}, py::is_operator())
357+
.def("GetEquationIdCsrIndices", [](CsrMatrix<double,IndexType>& rA, const NDData<int>& rConnectivities) {
358+
return GetEquationIdCsrIndices(rA, rConnectivities);
359+
}, py::return_value_policy::move)
309360
.def("GetEquationIdCsrIndices", [](CsrMatrix<double,IndexType>& rA, const ElementsContainerType& rElements, const ProcessInfo& rProcessInfo) {
310361
return GetEquationIdCsrIndices(rA, rElements, rProcessInfo);
311362
}, py::return_value_policy::move)

kratos/tests/test_sparse_matrices.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_matrix_assembly(self):
136136
for i in range(y.Size()):
137137
self.assertEqual(c[i], 5.0)
138138

139-
def test_get_equation_id_csr_indices(self):
139+
def test_get_equation_id_csr_indices_1(self):
140140
# Create model part
141141
model = KratosMultiphysics.Model()
142142
mp = model.CreateModelPart("Main")
@@ -187,6 +187,35 @@ def test_get_equation_id_csr_indices(self):
187187
for i, j in zip(data, expected_data):
188188
self.assertEqual(i, j)
189189

190+
def test_get_equation_id_csr_indices_2(self):
191+
# Create elements connectivities
192+
connectivities = KratosMultiphysics.IntNDData(np.array([[0,1,2],[1,3,2]]))
193+
194+
# Build sparse matrix graph
195+
connectivities_data = connectivities.to_numpy()
196+
graph = KratosMultiphysics.SparseContiguousRowGraph(4)
197+
for i in range(connectivities_data.shape[0]):
198+
graph.AddEntries(connectivities_data[i], connectivities_data[i])
199+
graph.Finalize()
200+
201+
# Create Matrix
202+
A = KratosMultiphysics.CsrMatrix(graph)
203+
204+
# Get the elemental contributions CSR indices
205+
elem_csr_indices = A.GetEquationIdCsrIndices(connectivities)
206+
207+
# Check results
208+
shape = elem_csr_indices.to_numpy().shape
209+
self.assertEqual(shape[0], 2)
210+
self.assertEqual(shape[1], 3)
211+
self.assertEqual(shape[2], 3)
212+
213+
data = elem_csr_indices.to_numpy().flatten()
214+
expected_data = [0, 1, 2, 3, 4, 5, 7, 8, 9,
215+
4, 6, 5, 11, 13, 12, 8, 10, 9]
216+
for i, j in zip(data, expected_data):
217+
self.assertEqual(i, j)
218+
190219
def test_assemble_with_csr_indices(self):
191220
# Create model part
192221
model = KratosMultiphysics.Model()

0 commit comments

Comments
 (0)