@@ -1175,7 +1175,7 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11751175 raise NotImplementedError (
11761176 f'Matrix expansion for boundary conditions not implemented for { ndim } dimensions!'
11771177 )
1178- mat .eliminate_zeros ()
1178+ mat = self .eliminate_zeros (mat )
11791179 return mat
11801180
11811181 def remove_BC (self , component , equation , axis , kind , line = - 1 , scalar = False , ** kwargs ):
@@ -1194,7 +1194,7 @@ def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kw
11941194 scalar (bool): Put the BC in all space positions in the other direction
11951195 """
11961196 _BC = self .get_BC (axis = axis , kind = kind , line = line , scalar = scalar , ** kwargs )
1197- _BC .eliminate_zeros ()
1197+ _BC = self .eliminate_zeros (_BC )
11981198 self .BC_mat [self .index (equation )][self .index (component )] -= _BC
11991199
12001200 if scalar :
@@ -1417,6 +1417,26 @@ def add_equation_lhs(self, A, equation, relations):
14171417 for k , v in relations .items ():
14181418 A [self .index (equation )][self .index (k )] = v
14191419
1420+ def eliminate_zeros (self , A ):
1421+ """
1422+ Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat.
1423+ Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`.
1424+ Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU.
1425+
1426+ Args:
1427+ A: sparse matrix to be pruned
1428+
1429+ Returns:
1430+ CSC sparse matrix
1431+ """
1432+ if self .useGPU :
1433+ A = A .get ()
1434+ A = A .tocsc ()
1435+ A .eliminate_zeros ()
1436+ if self .useGPU :
1437+ A = self .sparse_lib .csc_matrix (A )
1438+ return A
1439+
14201440 def convert_operator_matrix_to_operator (self , M ):
14211441 """
14221442 Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
@@ -1433,7 +1453,7 @@ def convert_operator_matrix_to_operator(self, M):
14331453 else :
14341454 op = self .sparse_lib .bmat (M , format = 'csc' )
14351455
1436- op .eliminate_zeros ()
1456+ op = self .eliminate_zeros (op )
14371457 return op
14381458
14391459 def get_wavenumbers (self ):
@@ -1811,8 +1831,7 @@ def expand_matrix_ND(self, matrix, aligned):
18111831 else :
18121832 raise NotImplementedError (f'Matrix expansion not implemented for { ndim } dimensions!' )
18131833
1814- mat = mat .tocsc ()
1815- mat .eliminate_zeros ()
1834+ mat = self .eliminate_zeros (mat )
18161835 return mat
18171836
18181837 def get_filter_matrix (self , axis , ** kwargs ):
0 commit comments