@@ -1135,7 +1135,7 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11351135
11361136 ndim = len (self .axes )
11371137 if ndim == 1 :
1138- return self .sparse_lib .csc_matrix (BC )
1138+ mat = self .sparse_lib .csc_matrix (BC )
11391139 elif ndim == 2 :
11401140 axis2 = (axis + 1 ) % ndim
11411141
@@ -1151,8 +1151,8 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11511151 ] * ndim
11521152 mats [axis ] = self .get_local_slice_of_1D_matrix (BC , axis = axis )
11531153 mats [axis2 ] = Id
1154- return self .sparse_lib .csc_matrix (self .sparse_lib .kron (* mats ))
1155- if ndim == 3 :
1154+ mat = self .sparse_lib .csc_matrix (self .sparse_lib .kron (* mats ))
1155+ elif ndim == 3 :
11561156 mats = [
11571157 None ,
11581158 ] * ndim
@@ -1170,11 +1170,13 @@ def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
11701170
11711171 mats [axis ] = self .get_local_slice_of_1D_matrix (BC , axis = axis )
11721172
1173- return self .sparse_lib .csc_matrix (self .sparse_lib .kron (mats [0 ], self .sparse_lib .kron (* mats [1 :])))
1173+ mat = self .sparse_lib .csc_matrix (self .sparse_lib .kron (mats [0 ], self .sparse_lib .kron (* mats [1 :])))
11741174 else :
11751175 raise NotImplementedError (
11761176 f'Matrix expansion for boundary conditions not implemented for { ndim } dimensions!'
11771177 )
1178+ mat .eliminate_zeros ()
1179+ return mat
11781180
11791181 def remove_BC (self , component , equation , axis , kind , line = - 1 , scalar = False , ** kwargs ):
11801182 """
@@ -1192,6 +1194,7 @@ def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kw
11921194 scalar (bool): Put the BC in all space positions in the other direction
11931195 """
11941196 _BC = self .get_BC (axis = axis , kind = kind , line = line , scalar = scalar , ** kwargs )
1197+ _BC .eliminate_zeros ()
11951198 self .BC_mat [self .index (equation )][self .index (component )] -= _BC
11961199
11971200 if scalar :
@@ -1375,7 +1378,7 @@ def put_BCs_in_rhs(self, rhs):
13751378
13761379 return rhs
13771380
1378- def add_equation_lhs (self , A , equation , relations , diag = False ):
1381+ def add_equation_lhs (self , A , equation , relations ):
13791382 """
13801383 Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
13811384 that you will convert to an operator later.
@@ -1410,16 +1413,11 @@ def add_equation_lhs(self, A, equation, relations, diag=False):
14101413 A (list of lists of sparse matrices): The operator to be
14111414 equation (str): The equation of the component you want this in
14121415 relations: (dict): Relations between quantities
1413- diag (bool): Whether operator is block-diagonal
14141416 """
14151417 for k , v in relations .items ():
1416- if diag :
1417- assert k == equation , 'You are trying to put a non-diagonal equation into a diagonal operator'
1418- A [self .index (equation )] = v
1419- else :
1420- A [self .index (equation )][self .index (k )] = v
1418+ A [self .index (equation )][self .index (k )] = v
14211419
1422- def convert_operator_matrix_to_operator (self , M , diag = False ):
1420+ def convert_operator_matrix_to_operator (self , M ):
14231421 """
14241422 Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
14251423 See documentation of `SpectralHelper.add_equation_lhs` for an example.
@@ -1431,14 +1429,12 @@ def convert_operator_matrix_to_operator(self, M, diag=False):
14311429 sparse linear operator
14321430 """
14331431 if len (self .components ) == 1 :
1434- if diag :
1435- return M [0 ]
1436- else :
1437- return M [0 ][0 ]
1438- elif diag :
1439- return self .sparse_lib .block_diag (M , format = 'csc' )
1432+ op = M [0 ][0 ]
14401433 else :
1441- return self .sparse_lib .block_array (M , format = 'csc' )
1434+ op = self .sparse_lib .bmat (M , format = 'csc' )
1435+
1436+ op .eliminate_zeros ()
1437+ return op
14421438
14431439 def get_wavenumbers (self ):
14441440 """
@@ -1792,7 +1788,7 @@ def expand_matrix_ND(self, matrix, aligned):
17921788 ndim = len (axes ) + 1
17931789
17941790 if ndim == 1 :
1795- return matrix
1791+ mat = matrix
17961792 elif ndim == 2 :
17971793 axis = axes [0 ]
17981794 I1D = sp .eye (self .axes [axis ].N )
@@ -1801,7 +1797,7 @@ def expand_matrix_ND(self, matrix, aligned):
18011797 mats [aligned ] = self .get_local_slice_of_1D_matrix (matrix , aligned )
18021798 mats [axis ] = self .get_local_slice_of_1D_matrix (I1D , axis )
18031799
1804- return sp .kron (* mats )
1800+ mat = sp .kron (* mats )
18051801 elif ndim == 3 :
18061802
18071803 mats = [None ] * ndim
@@ -1810,11 +1806,15 @@ def expand_matrix_ND(self, matrix, aligned):
18101806 I1D = sp .eye (self .axes [axis ].N )
18111807 mats [axis ] = self .get_local_slice_of_1D_matrix (I1D , axis )
18121808
1813- return sp .kron (mats [0 ], sp .kron (* mats [1 :]))
1809+ mat = sp .kron (mats [0 ], sp .kron (* mats [1 :]))
18141810
18151811 else :
18161812 raise NotImplementedError (f'Matrix expansion not implemented for { ndim } dimensions!' )
18171813
1814+ mat = mat .tocsc ()
1815+ mat .eliminate_zeros ()
1816+ return mat
1817+
18181818 def get_filter_matrix (self , axis , ** kwargs ):
18191819 """
18201820 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
0 commit comments