2727import math
2828import os
2929import time as time_mod
30+ from functools import partial
3031
3132import numpy
3233
33- device = os .getenv ("SHARPY_USE_GPU" , "" )
34+ try :
35+ import mpi4py
36+
37+ mpi4py .rc .finalize = False
38+ from mpi4py import MPI
39+
40+ comm_rank = MPI .COMM_WORLD .Get_rank ()
41+ comm = MPI .COMM_WORLD
42+ except ImportError :
43+ comm_rank = 0
44+ comm = None
45+
46+
47+ def info (s ):
48+ if comm_rank == 0 :
49+ print (s )
3450
3551
3652def run (n , backend , datatype , benchmark_mode ):
3753 if backend == "sharpy" :
3854 import sharpy as np
3955 from sharpy import fini , init , sync
40- from sharpy .numpy import fromfunction
56+ from sharpy .numpy import fromfunction as _fromfunction
57+
58+ device = os .getenv ("SHARPY_USE_GPU" , "" )
59+ create_full = partial (np .full , device = device )
60+ fromfunction = partial (_fromfunction , device = device )
4161
4262 all_axes = [0 , 1 ]
4363 init (False )
4464
45- try :
46- import mpi4py
47-
48- mpi4py .rc .finalize = False
49- from mpi4py import MPI
50-
51- comm_rank = MPI .COMM_WORLD .Get_rank ()
52- except ImportError :
53- comm_rank = 0
54-
5565 elif backend == "numpy" :
5666 import numpy as np
5767 from numpy import fromfunction
5868
69+ if comm is not None :
70+ assert (
71+ comm .Get_size () == 1
72+ ), "Numpy backend only supports serial execution."
73+
74+ create_full = np .full
75+
5976 fini = sync = lambda x = None : None
6077 all_axes = None
61- comm_rank = 0
6278 else :
6379 raise ValueError (f'Unknown backend: "{ backend } "' )
6480
65- def info (s ):
66- if comm_rank == 0 :
67- print (s )
68-
6981 info (f"Using backend: { backend } " )
7082
7183 dtype = {
@@ -102,32 +114,24 @@ def info(s):
102114 lambda i , j : xmin + i * dx + dx / 2 ,
103115 (nx , ny ),
104116 dtype = dtype ,
105- device = device ,
106117 )
107118 y_t_2d = fromfunction (
108119 lambda i , j : ymin + j * dy + dy / 2 ,
109120 (nx , ny ),
110121 dtype = dtype ,
111- device = device ,
112- )
113- x_u_2d = fromfunction (
114- lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype , device = device
115122 )
123+ x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
116124 y_u_2d = fromfunction (
117125 lambda i , j : ymin + j * dy + dy / 2 ,
118126 (nx + 1 , ny ),
119127 dtype = dtype ,
120- device = device ,
121128 )
122129 x_v_2d = fromfunction (
123130 lambda i , j : xmin + i * dx + dx / 2 ,
124131 (nx , ny + 1 ),
125132 dtype = dtype ,
126- device = device ,
127- )
128- y_v_2d = fromfunction (
129- lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype , device = device
130133 )
134+ y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
131135
132136 T_shape = (nx , ny )
133137 U_shape = (nx + 1 , ny )
@@ -144,32 +148,32 @@ def info(s):
144148 info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
145149
146150 # prognostic variables: elevation, (u, v) velocity
147- e = np . full (T_shape , 0.0 , dtype , device = device )
148- u = np . full (U_shape , 0.0 , dtype , device = device )
149- v = np . full (V_shape , 0.0 , dtype , device = device )
151+ e = create_full (T_shape , 0.0 , dtype )
152+ u = create_full (U_shape , 0.0 , dtype )
153+ v = create_full (V_shape , 0.0 , dtype )
150154
151155 # potential vorticity
152- q = np . full (F_shape , 0.0 , dtype , device = device )
156+ q = create_full (F_shape , 0.0 , dtype )
153157
154158 # bathymetry
155- h = np . full (T_shape , 0.0 , dtype , device = device )
159+ h = create_full (T_shape , 0.0 , dtype )
156160
157- hu = np . full (U_shape , 0.0 , dtype , device = device )
158- hv = np . full (V_shape , 0.0 , dtype , device = device )
161+ hu = create_full (U_shape , 0.0 , dtype )
162+ hv = create_full (V_shape , 0.0 , dtype )
159163
160- dudy = np . full (F_shape , 0.0 , dtype , device = device )
161- dvdx = np . full (F_shape , 0.0 , dtype , device = device )
164+ dudy = create_full (F_shape , 0.0 , dtype )
165+ dvdx = create_full (F_shape , 0.0 , dtype )
162166
163167 # vector invariant form
164- H_at_f = np . full (F_shape , 0.0 , dtype , device = device )
168+ H_at_f = create_full (F_shape , 0.0 , dtype )
165169
166170 # auxiliary variables for RK time integration
167- e1 = np . full (T_shape , 0.0 , dtype , device = device )
168- u1 = np . full (U_shape , 0.0 , dtype , device = device )
169- v1 = np . full (V_shape , 0.0 , dtype , device = device )
170- e2 = np . full (T_shape , 0.0 , dtype , device = device )
171- u2 = np . full (U_shape , 0.0 , dtype , device = device )
172- v2 = np . full (V_shape , 0.0 , dtype , device = device )
171+ e1 = create_full (T_shape , 0.0 , dtype )
172+ u1 = create_full (U_shape , 0.0 , dtype )
173+ v1 = create_full (V_shape , 0.0 , dtype )
174+ e2 = create_full (T_shape , 0.0 , dtype )
175+ u2 = create_full (U_shape , 0.0 , dtype )
176+ v2 = create_full (V_shape , 0.0 , dtype )
173177
174178 def exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d ):
175179 """
@@ -198,7 +202,7 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
198202 Water depth at rest
199203 """
200204 bath = 1.0
201- return bath * np . full (T_shape , 1.0 , dtype , device = device )
205+ return bath * create_full (T_shape , 1.0 , dtype )
202206
203207 # inital elevation
204208 u0 , v0 , e0 = exact_solution (
0 commit comments