@@ -45,9 +45,7 @@ def get_local_mesh(FFT, L):
4545 """Returns local mesh."""
4646 X = np .ogrid [FFT .local_slice (False )]
4747 N = FFT .global_shape ()
48- for i in range (len (N )):
49- X [i ] = (X [i ]* L [i ]/ N [i ])
50- X = [np .broadcast_to (x , FFT .shape (False )) for x in X ]
48+ X = [np .broadcast_to (x * L [i ]/ N [i ], FFT .shape (False )) for i , x in enumerate (X )]
5149 return X
5250
5351def get_local_wavenumbermesh (FFT , L ):
@@ -60,9 +58,7 @@ def get_local_wavenumbermesh(FFT, L):
6058 K = [ki [si ] for ki , si in zip (k , s )]
6159 Ks = np .meshgrid (* K , indexing = 'ij' , sparse = True )
6260 Lp = 2 * np .pi / L
63- for i in range (3 ):
64- Ks [i ] = (Ks [i ]* Lp [i ]).astype (float )
65- return [np .broadcast_to (k , FFT .shape (True )) for k in Ks ]
61+ return [np .broadcast_to (k * Lp [i ], FFT .shape (True )) for i , k in enumerate (Ks )]
6662
6763X = get_local_mesh (FFT , L )
6864K = get_local_wavenumbermesh (FFT , L )
@@ -131,3 +127,5 @@ def compute_rhs(rhs):
131127if MPI .COMM_WORLD .Get_rank () == 0 :
132128 print ('Time = {}' .format (time ()- t0 ))
133129 assert round (float (k ) - 0.124953117517 , 7 ) == 0
130+
131+ FFT .destroy ()
0 commit comments