@@ -1135,7 +1135,7 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11351135
11361136 ndim = len (self .axes )
11371137 if ndim == 1 :
1138- return self .sparse_lib .csc_matrix (BC )
1138+ mat = self .sparse_lib .csc_matrix (BC )
11391139 elif ndim == 2 :
11401140 axis2 = (axis + 1 ) % ndim
11411141
@@ -1151,8 +1151,8 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11511151 ] * ndim
11521152 mats [axis ] = self .get_local_slice_of_1D_matrix (BC , axis = axis )
11531153 mats [axis2 ] = Id
1154- return self .sparse_lib .csc_matrix (self .sparse_lib .kron (* mats ))
1155- if ndim == 3 :
1154+ mat = self .sparse_lib .csc_matrix (self .sparse_lib .kron (* mats ))
1155+ elif ndim == 3 :
11561156 mats = [
11571157 None ,
11581158 ] * ndim
@@ -1170,11 +1170,13 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11701170
11711171 mats [axis ] = self .get_local_slice_of_1D_matrix (BC , axis = axis )
11721172
1173- return self .sparse_lib .csc_matrix (self .sparse_lib .kron (mats [0 ], self .sparse_lib .kron (* mats [1 :])))
1173+ mat = self .sparse_lib .csc_matrix (self .sparse_lib .kron (mats [0 ], self .sparse_lib .kron (* mats [1 :])))
11741174 else :
11751175 raise NotImplementedError (
11761176 f'Matrix expansion for boundary conditions not implemented for { ndim } dimensions!'
11771177 )
1178+ mat = self .eliminate_zeros (mat )
1179+ return mat
11781180
11791181 def remove_BC (self , component , equation , axis , kind , line = - 1 , scalar = False , ** kwargs ):
11801182 """
@@ -1192,6 +1194,7 @@ def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kw
11921194 scalar (bool): Put the BC in all space positions in the other direction
11931195 """
11941196 _BC = self .get_BC (axis = axis , kind = kind , line = line , scalar = scalar , ** kwargs )
1197+ _BC = self .eliminate_zeros (_BC )
11951198 self .BC_mat [self .index (equation )][self .index (component )] -= _BC
11961199
11971200 if scalar :
@@ -1375,7 +1378,7 @@ def put_BCs_in_rhs(self, rhs):
13751378
13761379 return rhs
13771380
1378- def add_equation_lhs (self , A , equation , relations , diag = False ):
1381+ def add_equation_lhs (self , A , equation , relations ):
13791382 """
13801383 Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
13811384 that you will convert to an operator later.
@@ -1410,16 +1413,31 @@ def add_equation_lhs(self, A, equation, relations, diag=False):
14101413 A (list of lists of sparse matrices): The operator to be
14111414 equation (str): The equation of the component you want this in
14121415 relations: (dict): Relations between quantities
1413- diag (bool): Whether operator is block-diagonal
14141416 """
14151417 for k , v in relations .items ():
1416- if diag :
1417- assert k == equation , 'You are trying to put a non-diagonal equation into a diagonal operator'
1418- A [self .index (equation )] = v
1419- else :
1420- A [self .index (equation )][self .index (k )] = v
1418+ A [self .index (equation )][self .index (k )] = v
14211419
1422- def convert_operator_matrix_to_operator (self , M , diag = False ):
1420+ def eliminate_zeros (self , A ):
1421+ """
1422+ Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat.
1423+ Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`.
1424+ Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU.
1425+
1426+ Args:
1427+ A: sparse matrix to be pruned
1428+
1429+ Returns:
1430+ CSC sparse matrix
1431+ """
1432+ if self .useGPU :
1433+ A = A .get ()
1434+ A = A .tocsc ()
1435+ A .eliminate_zeros ()
1436+ if self .useGPU :
1437+ A = self .sparse_lib .csc_matrix (A )
1438+ return A
1439+
1440+ def convert_operator_matrix_to_operator (self , M ):
14231441 """
14241442 Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
14251443 See documentation of `SpectralHelper.add_equation_lhs` for an example.
@@ -1431,14 +1449,12 @@ def convert_operator_matrix_to_operator(self, M, diag=False):
14311449 sparse linear operator
14321450 """
14331451 if len (self .components ) == 1 :
1434- if diag :
1435- return M [0 ]
1436- else :
1437- return M [0 ][0 ]
1438- elif diag :
1439- return self .sparse_lib .block_diag (M , format = 'csc' )
1452+ op = M [0 ][0 ]
14401453 else :
1441- return self .sparse_lib .block_array (M , format = 'csc' )
1454+ op = self .sparse_lib .bmat (M , format = 'csc' )
1455+
1456+ op = self .eliminate_zeros (op )
1457+ return op
14421458
14431459 def get_wavenumbers (self ):
14441460 """
@@ -1792,7 +1808,7 @@ def expand_matrix_ND(self, matrix, aligned):
17921808 ndim = len (axes ) + 1
17931809
17941810 if ndim == 1 :
1795- return matrix
1811+ mat = matrix
17961812 elif ndim == 2 :
17971813 axis = axes [0 ]
17981814 I1D = sp .eye (self .axes [axis ].N )
@@ -1801,7 +1817,7 @@ def expand_matrix_ND(self, matrix, aligned):
18011817 mats [aligned ] = self .get_local_slice_of_1D_matrix (matrix , aligned )
18021818 mats [axis ] = self .get_local_slice_of_1D_matrix (I1D , axis )
18031819
1804- return sp .kron (* mats )
1820+ mat = sp .kron (* mats )
18051821 elif ndim == 3 :
18061822
18071823 mats = [None ] * ndim
@@ -1810,11 +1826,14 @@ def expand_matrix_ND(self, matrix, aligned):
18101826 I1D = sp .eye (self .axes [axis ].N )
18111827 mats [axis ] = self .get_local_slice_of_1D_matrix (I1D , axis )
18121828
1813- return sp .kron (mats [0 ], sp .kron (* mats [1 :]))
1829+ mat = sp .kron (mats [0 ], sp .kron (* mats [1 :]))
18141830
18151831 else :
18161832 raise NotImplementedError (f'Matrix expansion not implemented for { ndim } dimensions!' )
18171833
1834+ mat = self .eliminate_zeros (mat )
1835+ return mat
1836+
18181837 def get_filter_matrix (self , axis , ** kwargs ):
18191838 """
18201839 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
0 commit comments