2828import time as time_mod
2929import argparse
3030
31+ import os
32+
33+ device = os .getenv ("SHARPY_USE_GPU" , "" )
34+
3135
3236def run (n , backend , datatype , benchmark_mode ):
3337 if backend == "sharpy" :
@@ -94,16 +98,24 @@ def info(s):
9498 t_end = 1.0
9599
96100 # coordinate arrays
97- x_t_2d = fromfunction (lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype )
98- y_t_2d = fromfunction (lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype )
99- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
101+ x_t_2d = fromfunction (
102+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype , device = device
103+ )
104+ y_t_2d = fromfunction (
105+ lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype , device = device
106+ )
107+ x_u_2d = fromfunction (
108+ lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype , device = device
109+ )
100110 y_u_2d = fromfunction (
101- lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = dtype
111+ lambda i , j : ymin + j * dy + dy / 2 , (nx + 1 , ny ), dtype = dtype , device = device
102112 )
103113 x_v_2d = fromfunction (
104- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = dtype
114+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny + 1 ), dtype = dtype , device = device
115+ )
116+ y_v_2d = fromfunction (
117+ lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype , device = device
105118 )
106- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
107119
108120 T_shape = (nx , ny )
109121 U_shape = (nx + 1 , ny )
@@ -120,32 +132,32 @@ def info(s):
120132 info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
121133
122134 # prognostic variables: elevation, (u, v) velocity
123- e = np .full (T_shape , 0.0 , dtype )
124- u = np .full (U_shape , 0.0 , dtype )
125- v = np .full (V_shape , 0.0 , dtype )
135+ e = np .full (T_shape , 0.0 , dtype , device = device )
136+ u = np .full (U_shape , 0.0 , dtype , device = device )
137+ v = np .full (V_shape , 0.0 , dtype , device = device )
126138
127139 # potential vorticity
128- q = np .full (F_shape , 0.0 , dtype )
140+ q = np .full (F_shape , 0.0 , dtype , device = device )
129141
130142 # bathymetry
131- h = np .full (T_shape , 0.0 , dtype )
143+ h = np .full (T_shape , 0.0 , dtype , device = device )
132144
133- hu = np .full (U_shape , 0.0 , dtype )
134- hv = np .full (V_shape , 0.0 , dtype )
145+ hu = np .full (U_shape , 0.0 , dtype , device = device )
146+ hv = np .full (V_shape , 0.0 , dtype , device = device )
135147
136- dudy = np .full (F_shape , 0.0 , dtype )
137- dvdx = np .full (F_shape , 0.0 , dtype )
148+ dudy = np .full (F_shape , 0.0 , dtype , device = device )
149+ dvdx = np .full (F_shape , 0.0 , dtype , device = device )
138150
139151 # vector invariant form
140- H_at_f = np .full (F_shape , 0.0 , dtype )
152+ H_at_f = np .full (F_shape , 0.0 , dtype , device = device )
141153
142154 # auxiliary variables for RK time integration
143- e1 = np .full (T_shape , 0.0 , dtype )
144- u1 = np .full (U_shape , 0.0 , dtype )
145- v1 = np .full (V_shape , 0.0 , dtype )
146- e2 = np .full (T_shape , 0.0 , dtype )
147- u2 = np .full (U_shape , 0.0 , dtype )
148- v2 = np .full (V_shape , 0.0 , dtype )
155+ e1 = np .full (T_shape , 0.0 , dtype , device = device )
156+ u1 = np .full (U_shape , 0.0 , dtype , device = device )
157+ v1 = np .full (V_shape , 0.0 , dtype , device = device )
158+ e2 = np .full (T_shape , 0.0 , dtype , device = device )
159+ u2 = np .full (U_shape , 0.0 , dtype , device = device )
160+ v2 = np .full (V_shape , 0.0 , dtype , device = device )
149161
150162 def exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d ):
151163 """
@@ -174,7 +186,7 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
174186 Water depth at rest
175187 """
176188 bath = 1.0
177- return bath * np .full (T_shape , 1.0 , dtype )
189+ return bath * np .full (T_shape , 1.0 , dtype , device = device )
178190
179191 # inital elevation
180192 u0 , v0 , e0 = exact_solution (0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d )
0 commit comments