Skip to content

Commit decd4c6

Browse files
chore: lint
1 parent 50c2b0b commit decd4c6

File tree

8 files changed

+299
-197
lines changed

8 files changed

+299
-197
lines changed

miniwave/miniwave.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,102 +4,132 @@
44
from utils import Model, Compiler, Kernel, plot
55
from utils.properties import Properties
66

7+
78
def get_args(args=sys.argv[1:]):
8-
9+
910
parser = argparse.ArgumentParser(description='How to use this program')
10-
11+
1112
parser.add_argument("--file", type=str, default='kernels/sequential.c',
1213
help="Path to the Kernel file")
13-
14+
1415
parser.add_argument("--grid_size", type=int, default=256,
1516
help="Grid size")
16-
17+
1718
parser.add_argument("--num_timesteps", type=int, default=400,
1819
help="Number of timesteps")
19-
20-
parser.add_argument("--language", type=str, default="c", choices=['c', 'openmp', 'openmp_cpu', 'openacc', 'cuda', 'python', 'mpi', 'mpi_cuda', 'ompc'],
21-
help="Language: c, openmp, openacc, cuda, python, ompc, mpi, mpi_cuda")
22-
20+
21+
parser.add_argument(
22+
"--language",
23+
type=str,
24+
default="c",
25+
choices=[
26+
'c',
27+
'openmp',
28+
'openmp_cpu',
29+
'openacc',
30+
'cuda',
31+
'python',
32+
'mpi',
33+
'mpi_cuda',
34+
'ompc'
35+
],
36+
help="Language: c, openmp, openacc, cuda, python, ompc, mpi, mpi_cuda"
37+
)
38+
2339
parser.add_argument("--space_order", type=int, default=2,
2440
help="Space order")
25-
41+
2642
parser.add_argument("--block_size_1", type=int, default=1,
2743
help="GPU block size in the first axis")
28-
44+
2945
parser.add_argument("--block_size_2", type=int, default=1,
3046
help="GPU block size in the second axis")
31-
47+
3248
parser.add_argument("--block_size_3", type=int, default=1,
33-
help="GPU block size in the third axis")
34-
49+
help="GPU block size in the third axis")
50+
3551
parser.add_argument("--sm", type=int, default=75,
36-
help="Cuda capability")
37-
38-
parser.add_argument("--fast_math", default=False, action="store_true" , help="Enable --fast-math flag")
39-
40-
parser.add_argument("--plot", default=False, action="store_true" , help="Enable ploting")
41-
42-
parser.add_argument("--dtype", type=str, default="float64", help="Float Precision. float32 or float64 (default)")
43-
44-
52+
help="Cuda capability")
53+
54+
parser.add_argument(
55+
"--fast_math",
56+
default=False,
57+
action="store_true",
58+
help="Enable --fast-math flag"
59+
)
60+
61+
parser.add_argument(
62+
"--plot",
63+
default=False,
64+
action="store_true",
65+
help="Enable ploting"
66+
)
67+
68+
parser.add_argument(
69+
"--dtype",
70+
type=str,
71+
default="float64",
72+
help="Float Precision. float32 or float64 (default)"
73+
)
74+
4575
parsed_args = parser.parse_args(args)
4676

4777
return parsed_args
4878

4979

5080
if __name__ == "__main__":
51-
81+
5282
args = get_args()
53-
54-
# enable/disable fast math
83+
84+
# enable/disable fast math
5585
fast_math = args.fast_math
56-
86+
5787
# cuda capability
58-
sm = args.sm
59-
88+
sm = args.sm
89+
6090
# language
6191
language = args.language
62-
92+
6393
# float precision
6494
dtype = args.dtype
65-
95+
6696
# create a compiler object
6797
compiler = Compiler(language=language, sm=sm, fast_math=fast_math)
68-
98+
6999
# define grid shape
70100
grid_size = (args.grid_size, args.grid_size, args.grid_size)
71-
101+
72102
vel_model = np.ones(shape=grid_size) * 1500.0
73-
103+
74104
model = Model(
75105
velocity_model=vel_model,
76-
grid_spacing=(10,10,10),
106+
grid_spacing=(10, 10, 10),
77107
dt=0.002,
78108
num_timesteps=args.num_timesteps,
79109
space_order=args.space_order,
80110
dtype=dtype
81-
)
82-
83-
# GPU block sizes
111+
)
112+
113+
# GPU block sizes
84114
properties = Properties(
85115
block_size_1=args.block_size_1,
86-
block_size_2=args.block_size_2,
116+
block_size_2=args.block_size_2,
87117
block_size_3=args.block_size_3
88-
)
89-
118+
)
119+
90120
solver = Kernel(
91-
file=args.file,
121+
file=args.file,
92122
model=model,
93123
compiler=compiler,
94124
properties=properties
95-
)
96-
125+
)
126+
97127
# run the kernel
98128
exec_time, u = solver.run()
99-
129+
100130
# plot a slice
101131
if args.plot:
102-
slice = vel_model.shape[1] // 2
103-
plot(u[:,slice,:])
104-
132+
slice = vel_model.shape[1] // 2
133+
plot(u[:, slice, :])
134+
105135
print(f"Execution time: {exec_time} seconds")

miniwave/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
# flake8: noqa
12
from .model import Model
23
from .compiler import Compiler
34
from .kernel import Kernel
45
from .plot import plot
5-
from .properties import Properties
6+
from .properties import Properties

0 commit comments

Comments
 (0)