@@ -1178,30 +1178,36 @@ def put_BCs_in_matrix(self, A):
11781178 def put_BCs_in_rhs_hat (self , rhs_hat ):
11791179 """
11801180 Put the BCs in the right hand side in spectral space for solving.
1181- This function needs no transforms.
1181+ This function needs no transforms and caches a mask for faster subsequent use .
11821182
11831183 Args:
11841184 rhs_hat: Right hand side in spectral space
11851185
11861186 Returns:
11871187 rhs in spectral space with BCs
11881188 """
1189- ndim = self .ndim
1190-
1191- for axis in range (ndim ):
1192- for bc in self .full_BCs :
1193- slices = (
1194- [slice (0 , self .init [0 ][i + 1 ]) for i in range (axis )]
1195- + [bc ['line' ]]
1196- + [slice (0 , self .init [0 ][i + 1 ]) for i in range (axis + 1 , len (self .axes ))]
1197- )
1198- if axis == bc ['axis' ]:
1199- _slice = [self .index (bc ['equation' ])] + slices
1200- N = self .axes [axis ].N
1201- if (N + bc ['line' ]) % N in self .xp .arange (N )[self .local_slice [axis ]]:
1202- _slice [axis + 1 ] -= self .local_slice [axis ].start
1203- rhs_hat [(* _slice ,)] = 0
1204-
1189+ if not hasattr (self , '_rhs_hat_zero_mask' ):
1190+ """
1191+ Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
1192+ by the boundary conditions. The mask is then cached.
1193+ """
1194+ self ._rhs_hat_zero_mask = self .xp .zeros (shape = rhs_hat .shape , dtype = bool )
1195+
1196+ for axis in range (self .ndim ):
1197+ for bc in self .full_BCs :
1198+ slices = (
1199+ [slice (0 , self .init [0 ][i + 1 ]) for i in range (axis )]
1200+ + [bc ['line' ]]
1201+ + [slice (0 , self .init [0 ][i + 1 ]) for i in range (axis + 1 , len (self .axes ))]
1202+ )
1203+ if axis == bc ['axis' ]:
1204+ _slice = [self .index (bc ['equation' ])] + slices
1205+ N = self .axes [axis ].N
1206+ if (N + bc ['line' ]) % N in self .xp .arange (N )[self .local_slice [axis ]]:
1207+ _slice [axis + 1 ] -= self .local_slice [axis ].start
1208+ self ._rhs_hat_zero_mask [(* _slice ,)] = True
1209+
1210+ rhs_hat [self ._rhs_hat_zero_mask ] = 0
12051211 return rhs_hat + self .rhs_BCs_hat
12061212
12071213 def put_BCs_in_rhs (self , rhs ):
@@ -1347,18 +1353,22 @@ def get_fft(self, axes=None, direction='object', padding=None, shape=None):
13471353 elif direction == 'object' :
13481354 self .fft_cache [key ] = None
13491355 else :
1350- from mpi4py_fft import PFFT
1351-
1352- _fft = PFFT (
1353- comm = self .comm ,
1354- shape = shape ,
1355- axes = sorted (axes ),
1356- dtype = 'D' ,
1357- collapse = False ,
1358- backend = self .fft_backend ,
1359- comm_backend = self .fft_comm_backend ,
1360- padding = padding ,
1361- )
1356+ if direction == 'object' :
1357+ from mpi4py_fft import PFFT
1358+
1359+ _fft = PFFT (
1360+ comm = self .comm ,
1361+ shape = shape ,
1362+ axes = sorted (axes ),
1363+ dtype = 'D' ,
1364+ collapse = False ,
1365+ backend = self .fft_backend ,
1366+ comm_backend = self .fft_comm_backend ,
1367+ padding = padding ,
1368+ )
1369+ else :
1370+ _fft = self .get_fft (axes = axes , direction = 'object' , padding = padding , shape = shape )
1371+
13621372 if direction == 'forward' :
13631373 self .fft_cache [key ] = _fft .forward
13641374 elif direction == 'backward' :
0 commit comments