@@ -44,14 +44,62 @@ namespace
4444
4545 using EquationIdVectorType = std::vector<std::size_t >;
4646
47+ /* *
48+ * @brief Get the Equation Id Csr Indices object
49+ * @details This function computes the CSR matrix value vector indices corresponding to the equation ids of the entities in the provided container.
50+ * @tparam TContainerType Type of the container of the entities.
51+ * @param rCsrMatrix The CSR matrix to get the indices from.
52+ * @param rConnectivities The connectivities of the entities.
53+ * @return NDData The NDData object containing the indices.
54+ */
55+ NDData<int > GetEquationIdCsrIndices (
56+ const CsrMatrixType& rCsrMatrix,
57+ const NDData<int >& rConnectivities)
58+ {
59+ // Get shapes from the input connectivities
60+ const auto & r_shape = rConnectivities.Shape ();
61+ KRATOS_ERROR_IF (r_shape.size () != 2 ) << " Input connectivities must have shape (n_entities, local_size)" << std::endl;
62+ const std::size_t n_entities = r_shape[0 ];
63+ const std::size_t local_size = r_shape[1 ];
64+
65+ // Assign the input NDData to have shape: number of entities * local_size * local_size
66+ DenseVector<unsigned int > nd_data_shape (3 );
67+ nd_data_shape[0 ] = n_entities;
68+ nd_data_shape[1 ] = local_size;
69+ nd_data_shape[2 ] = local_size;
70+ auto eq_ids_data = NDData<int >(nd_data_shape);
71+
72+ // Loop over the connectivities
73+ auto eq_ids_data_view = eq_ids_data.ViewData ();
74+ auto connectivities_view = rConnectivities.ViewData ();
75+ IndexPartition<std::size_t >(n_entities).for_each ([&](std::size_t i) {
76+ // Get current entity position in the connectivities view
77+ const std::size_t idx_start = i * local_size;
78+ const std::size_t eq_ids_pos = i * (local_size * local_size);
79+
80+ // Loop over the DOFs
81+ for (unsigned int i_local = 0 ; i_local < local_size; ++i_local) {
82+ const unsigned int i_global = connectivities_view[idx_start + i_local]; // Row global equation id
83+ for (unsigned int j_local = 0 ; j_local < local_size; ++j_local) {
84+ const unsigned int j_global = connectivities_view[idx_start + j_local]; // Column global equation id
85+ const unsigned int csr_index = rCsrMatrix.FindValueIndex (i_global,j_global); // Index in the CSR matrix values vector
86+ eq_ids_data_view[eq_ids_pos + i_local * local_size + j_local] = csr_index;
87+ }
88+ }
89+ });
90+
91+ // Return the data container
92+ return eq_ids_data;
93+ }
94+
4795 /* *
4896 * @brief Get the Equation Id Csr Indices object
4997 * @details This function computes the CSR matrix value vector indices corresponding to the equation ids of the entities in the provided container.
5098 * @tparam TContainerType Type of the container of the entities.
5199 * @param rCsrMatrix The CSR matrix to get the indices from.
52100 * @param rContainer The container of the entities.
53101 * @param rProcessInfo The process info.
54- * @param pNDData Pointer to the NDData object to store the indices.
102+ * @return NDData The NDData object containing the indices.
55103 */
56104 template <class TContainerType >
57105 NDData<int > GetEquationIdCsrIndices (
@@ -73,21 +121,21 @@ namespace
73121 auto eq_ids_data = NDData<int >(nd_data_shape);
74122
75123 // Loop over the container
124+ EquationIdVectorType aux_tls;
76125 auto eq_ids_data_view = eq_ids_data.ViewData ();
77- IndexPartition<std::size_t >(n_entities).for_each ([&](std::size_t i) {
126+ IndexPartition<std::size_t >(n_entities).for_each (aux_tls, [&](std::size_t i, EquationIdVectorType& rTLS ) {
78127 // Get current entity
79128 auto it = rContainer.begin () + i;
80129 const std::size_t it_pos = i * (local_size * local_size);
81130
82131 // Get current entity equation ids
83- EquationIdVectorType equation_ids;
84- it->EquationIdVector (equation_ids, rProcessInfo);
132+ it->EquationIdVector (rTLS, rProcessInfo);
85133
86134 // Loop over the DOFs
87135 for (unsigned int i_local = 0 ; i_local < local_size; ++i_local) {
88- const unsigned int i_global = equation_ids [i_local]; // Row global equation id
136+ const unsigned int i_global = rTLS [i_local]; // Row global equation id
89137 for (unsigned int j_local = 0 ; j_local < local_size; ++j_local) {
90- const unsigned int j_global = equation_ids [j_local]; // Column global equation id
138+ const unsigned int j_global = rTLS [j_local]; // Column global equation id
91139 const unsigned int csr_index = rCsrMatrix.FindValueIndex (i_global,j_global); // Index in the CSR matrix values vector
92140 eq_ids_data_view[it_pos + i_local * local_size + j_local] = csr_index;
93141 }
@@ -132,7 +180,7 @@ namespace
132180 const std::size_t aux_idx = entity_pos + i_local * local_size_1 + j_local; // Position in the contributions and equation ids data
133181 const int csr_index = r_idx_data[aux_idx]; // Index in the CSR matrix values vector
134182 const double lhs_contribution = r_lhs_contribution_data[aux_idx]; // Scalar contribution to the left hand side
135- r_lhs_data[csr_index] += lhs_contribution;
183+ AtomicAdd ( r_lhs_data[csr_index], lhs_contribution) ;
136184 }
137185 }
138186 });
@@ -306,6 +354,9 @@ void AddSparseMatricesToPython(pybind11::module& m)
306354 .def (" __matmul__" , [](CsrMatrix<double ,IndexType>& rA,CsrMatrix<double ,IndexType>& rB){
307355 return AmgclCSRSpMMUtilities::SparseMultiply (rA,rB);
308356 }, py::is_operator ())
357+ .def (" GetEquationIdCsrIndices" , [](CsrMatrix<double ,IndexType>& rA, const NDData<int >& rConnectivities) {
358+ return GetEquationIdCsrIndices (rA, rConnectivities);
359+ }, py::return_value_policy::move)
309360 .def (" GetEquationIdCsrIndices" , [](CsrMatrix<double ,IndexType>& rA, const ElementsContainerType& rElements, const ProcessInfo& rProcessInfo) {
310361 return GetEquationIdCsrIndices (rA, rElements, rProcessInfo);
311362 }, py::return_value_policy::move)
0 commit comments