Skip to content

Commit 6f8a443

Browse files
authored
issue/925: Speed up scripts/build_ntops.py and src/infiniop/ninetoothed/build.py with concurrent.futures (#926)
1 parent 53a1969 commit 6f8a443

File tree

2 files changed

+68
-37
lines changed

2 files changed

+68
-37
lines changed

scripts/build_ntops.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import concurrent.futures
12
import importlib
23
import pathlib
34

@@ -11,16 +12,27 @@
1112
def _find_and_build_ops():
1213
ops_path = SRC_DIR_PATH / "infiniop" / "ops"
1314

14-
for op_dir in ops_path.iterdir():
15-
ninetoothed_path = op_dir / "ninetoothed"
15+
with concurrent.futures.ProcessPoolExecutor() as executor:
16+
futures = []
1617

17-
if ninetoothed_path.is_dir():
18-
module_path = ninetoothed_path / "build"
19-
relative_path = module_path.relative_to(SRC_DIR_PATH)
20-
import_name = ".".join(relative_path.parts)
21-
module = importlib.import_module(import_name)
18+
for op_dir in ops_path.iterdir():
19+
ninetoothed_path = op_dir / "ninetoothed"
2220

23-
module.build()
21+
if not ninetoothed_path.is_dir():
22+
continue
23+
24+
futures.append(executor.submit(_build, ninetoothed_path))
25+
26+
concurrent.futures.as_completed(futures)
27+
28+
29+
def _build(ninetoothed_path):
30+
module_path = ninetoothed_path / "build"
31+
relative_path = module_path.relative_to(SRC_DIR_PATH)
32+
import_name = ".".join(relative_path.parts)
33+
module = importlib.import_module(import_name)
34+
35+
module.build()
2436

2537

2638
if __name__ == "__main__":

src/infiniop/ninetoothed/build.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import concurrent.futures
12
import functools
23
import inspect
34
import itertools
@@ -16,40 +17,28 @@
1617
def build(premake, constexpr_param_grid, caller, op_name, output_dir):
1718
headers = []
1819
all_param_names = []
20+
combinations = []
1921
launches = []
2022

21-
for combination in _generate_param_value_combinations(constexpr_param_grid):
22-
arrangement, application, tensors = premake(**combination)
23+
with concurrent.futures.ProcessPoolExecutor() as executor:
24+
futures = []
2325

24-
for param_name, param_value in combination.items():
25-
if isinstance(param_value, str):
26-
combination[param_name] = (
27-
f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
28-
)
26+
for combination in tuple(
27+
_generate_param_value_combinations(constexpr_param_grid)
28+
):
29+
future = executor.submit(
30+
_make, premake, combination, caller, op_name, output_dir
31+
)
2932

30-
combination = {f"{name}_": value for name, value in combination.items()}
33+
futures.append(future)
3134

32-
kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
35+
for future in concurrent.futures.as_completed(futures):
36+
header, param_names, combination, launch = future.result()
3337

34-
ninetoothed.make(
35-
arrangement,
36-
application,
37-
tensors,
38-
caller=caller,
39-
kernel_name=kernel_name,
40-
output_dir=output_dir,
41-
)
42-
43-
header = output_dir / f"{kernel_name}.h"
44-
param_names = ("stream",) + tuple(
45-
inspect.signature(application).parameters.keys()
46-
)
47-
launch = f""" if ({_generate_condition(combination)})
48-
return launch_{kernel_name}({", ".join(param_names)});"""
49-
50-
headers.append(header)
51-
all_param_names.append(param_names)
52-
launches.append(launch)
38+
headers.append(header)
39+
all_param_names.append(param_names)
40+
combinations.append(combination)
41+
launches.append(launch)
5342

5443
includes = "\n".join(f'#include "{header}"' for header in headers)
5544

@@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
6453
"NineToothedStream",
6554
] + ["NineToothedTensor" for _ in range(len(param_names) - 1)]
6655

67-
for param_name in combination:
56+
for param_name in functools.reduce(lambda x, y: x | y, combinations, {}):
6857
param_names.append(param_name)
6958
param_types.append("int")
7059

@@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
9786
(BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)
9887

9988

89+
def _make(premake, combination, caller, op_name, output_dir):
90+
arrangement, application, tensors = premake(**combination)
91+
92+
for param_name, param_value in combination.items():
93+
if isinstance(param_value, str):
94+
combination[param_name] = (
95+
f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}"
96+
)
97+
98+
combination = {f"{name}_": value for name, value in combination.items()}
99+
100+
kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
101+
102+
ninetoothed.make(
103+
arrangement,
104+
application,
105+
tensors,
106+
caller=caller,
107+
kernel_name=kernel_name,
108+
output_dir=output_dir,
109+
)
110+
111+
header = output_dir / f"{kernel_name}.h"
112+
param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys())
113+
launch = f""" if ({_generate_condition(combination)})
114+
return launch_{kernel_name}({", ".join(param_names)});"""
115+
116+
return header, param_names, combination, launch
117+
118+
100119
def _generate_condition(combination):
101120
return " && ".join(f"{param} == {value}" for param, value in combination.items())
102121

0 commit comments

Comments
 (0)