Skip to content

Commit 3185c67

Browse files
committed
Improve generate script
- Fix issue with __pycache__ dirs getting picked up
- Parallelise code generation with asyncio for 3x speedup
- Silence protoc output unless -v option is supplied
- Use pathlib ;)
1 parent 4b6f55d commit 3185c67

File tree

3 files changed

+99
-72
lines changed

3 files changed

+99
-72
lines changed

betterproto/tests/generate.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
2-
import glob
2+
import asyncio
33
import os
4+
from pathlib import Path
45
import shutil
56
import subprocess
67
import sys
@@ -20,91 +21,122 @@
2021
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
2122

2223

23-
def clear_directory(path: str):
24-
for file_or_directory in glob.glob(os.path.join(path, "*")):
25-
if os.path.isdir(file_or_directory):
24+
def clear_directory(dir_path: Path):
25+
for file_or_directory in dir_path.glob("*"):
26+
if file_or_directory.is_dir():
2627
shutil.rmtree(file_or_directory)
2728
else:
28-
os.remove(file_or_directory)
29+
file_or_directory.unlink()
2930

3031

31-
def generate(whitelist: Set[str]):
32-
path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)}
33-
name_whitelist = {e for e in whitelist if not os.path.exists(e)}
32+
async def generate(whitelist: Set[str], verbose: bool):
33+
test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
3434

35-
test_case_names = set(get_directories(inputs_path))
36-
37-
failed_test_cases = []
35+
path_whitelist = set()
36+
name_whitelist = set()
37+
for item in whitelist:
38+
if item in test_case_names:
39+
name_whitelist.add(item)
40+
continue
41+
path_whitelist.add(item)
3842

43+
generation_tasks = []
3944
for test_case_name in sorted(test_case_names):
40-
test_case_input_path = os.path.realpath(
41-
os.path.join(inputs_path, test_case_name)
42-
)
43-
45+
test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
4446
if (
4547
whitelist
46-
and test_case_input_path not in path_whitelist
48+
and str(test_case_input_path) not in path_whitelist
4749
and test_case_name not in name_whitelist
4850
):
4951
continue
52+
generation_tasks.append(
53+
generate_test_case_output(test_case_input_path, test_case_name, verbose)
54+
)
5055

51-
print(f"Generating output for {test_case_name}")
52-
try:
53-
generate_test_case_output(test_case_name, test_case_input_path)
54-
except subprocess.CalledProcessError as e:
56+
failed_test_cases = []
57+
# Wait for all subprocs and match any failures to names to report
58+
for test_case_name, result in zip(
59+
sorted(test_case_names), await asyncio.gather(*generation_tasks)
60+
):
61+
if result != 0:
5562
failed_test_cases.append(test_case_name)
5663

5764
if failed_test_cases:
58-
sys.stderr.write("\nFailed to generate the following test cases:\n")
65+
sys.stderr.write(
66+
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
67+
)
5968
for failed_test_case in failed_test_cases:
6069
sys.stderr.write(f"- {failed_test_case}\n")
6170

6271

63-
def generate_test_case_output(test_case_name, test_case_input_path=None):
64-
if not test_case_input_path:
65-
test_case_input_path = os.path.realpath(
66-
os.path.join(inputs_path, test_case_name)
67-
)
72+
async def generate_test_case_output(
73+
test_case_input_path: Path, test_case_name: str, verbose: bool
74+
) -> int:
75+
"""
76+
Returns the max of the subprocess return values
77+
"""
6878

69-
test_case_output_path_reference = os.path.join(
70-
output_path_reference, test_case_name
71-
)
72-
test_case_output_path_betterproto = os.path.join(
73-
output_path_betterproto, test_case_name
74-
)
79+
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
80+
test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name)
7581

7682
os.makedirs(test_case_output_path_reference, exist_ok=True)
7783
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
7884

7985
clear_directory(test_case_output_path_reference)
8086
clear_directory(test_case_output_path_betterproto)
8187

82-
protoc_reference(test_case_input_path, test_case_output_path_reference)
83-
protoc_plugin(test_case_input_path, test_case_output_path_betterproto)
88+
(
89+
(ref_out, ref_err, ref_code),
90+
(plg_out, plg_err, plg_code),
91+
) = await asyncio.gather(
92+
protoc_reference(test_case_input_path, test_case_output_path_reference),
93+
protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
94+
)
95+
96+
message = f"Generated output for {test_case_name!r}"
97+
if verbose:
98+
print(f"\033[31;1;4m{message}\033[0m")
99+
if ref_out:
100+
sys.stdout.buffer.write(ref_out)
101+
if ref_err:
102+
sys.stderr.buffer.write(ref_err)
103+
if plg_out:
104+
sys.stdout.buffer.write(plg_out)
105+
if plg_err:
106+
sys.stderr.buffer.write(plg_err)
107+
sys.stdout.buffer.flush()
108+
sys.stderr.buffer.flush()
109+
else:
110+
print(message)
111+
112+
return max(ref_code, plg_code)
84113

85114

86115
HELP = "\n".join(
87-
[
88-
"Usage: python generate.py",
89-
" python generate.py [DIRECTORIES or NAMES]",
116+
(
117+
"Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
90118
"Generate python classes for standard tests.",
91119
"",
92120
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
93121
" python generate.py inputs/bool inputs/double inputs/enum",
94122
"",
95123
"NAMES One or more test-case names to generate classes for.",
96124
" python generate.py bool double enums",
97-
]
125+
)
98126
)
99127

100128

101129
def main():
102130
if set(sys.argv).intersection({"-h", "--help"}):
103131
print(HELP)
104132
return
105-
whitelist = set(sys.argv[1:])
106-
107-
generate(whitelist)
133+
if sys.argv[1:2] == ["-v"]:
134+
verbose = True
135+
whitelist = set(sys.argv[2:])
136+
else:
137+
verbose = False
138+
whitelist = set(sys.argv[1:])
139+
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))
108140

109141

110142
if __name__ == "__main__":

betterproto/tests/test_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
class TestCases:
2525
def __init__(self, path, services: Set[str], xfail: Set[str]):
26-
_all = set(get_directories(path))
26+
_all = set(get_directories(path)) - {"__pycache__"}
2727
_services = services
2828
_messages = _all - services
2929
_messages_with_json = {

betterproto/tests/util.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
1+
import asyncio
12
import os
2-
import subprocess
3-
from typing import Generator
3+
from pathlib import Path
4+
from typing import Generator, IO, Optional
45

56
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
67

7-
root_path = os.path.dirname(os.path.realpath(__file__))
8-
inputs_path = os.path.join(root_path, "inputs")
9-
output_path_reference = os.path.join(root_path, "output_reference")
10-
output_path_betterproto = os.path.join(root_path, "output_betterproto")
8+
root_path = Path(__file__).resolve().parent
9+
inputs_path = root_path.joinpath("inputs")
10+
output_path_reference = root_path.joinpath("output_reference")
11+
output_path_betterproto = root_path.joinpath("output_betterproto")
1112

1213
if os.name == "nt":
13-
plugin_path = os.path.join(root_path, "..", "plugin.bat")
14+
plugin_path = root_path.joinpath("..", "plugin.bat").resolve()
1415
else:
15-
plugin_path = os.path.join(root_path, "..", "plugin.py")
16+
plugin_path = root_path.joinpath("..", "plugin.py").resolve()
1617

1718

18-
def get_files(path, end: str) -> Generator[str, None, None]:
19+
def get_files(path, suffix: str) -> Generator[str, None, None]:
1920
for r, dirs, files in os.walk(path):
20-
for filename in [f for f in files if f.endswith(end)]:
21+
for filename in [f for f in files if f.endswith(suffix)]:
2122
yield os.path.join(r, filename)
2223

2324

@@ -27,36 +28,30 @@ def get_directories(path):
2728
yield directory
2829

2930

30-
def relative(file: str, path: str):
31-
return os.path.join(os.path.dirname(file), path)
32-
33-
34-
def read_relative(file: str, path: str):
35-
with open(relative(file, path)) as fh:
36-
return fh.read()
37-
38-
39-
def protoc_plugin(path: str, output_dir: str) -> subprocess.CompletedProcess:
40-
return subprocess.run(
31+
async def protoc_plugin(path: str, output_dir: str):
32+
proc = await asyncio.create_subprocess_shell(
4133
f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
42-
shell=True,
43-
check=True,
34+
stdout=asyncio.subprocess.PIPE,
35+
stderr=asyncio.subprocess.PIPE,
4436
)
37+
return (*(await proc.communicate()), proc.returncode)
4538

4639

47-
def protoc_reference(path: str, output_dir: str):
48-
subprocess.run(
40+
async def protoc_reference(path: str, output_dir: str):
41+
proc = await asyncio.create_subprocess_shell(
4942
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
50-
shell=True,
43+
stdout=asyncio.subprocess.PIPE,
44+
stderr=asyncio.subprocess.PIPE,
5145
)
46+
return (*(await proc.communicate()), proc.returncode)
5247

5348

54-
def get_test_case_json_data(test_case_name, json_file_name=None):
49+
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
5550
test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
56-
test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name)
51+
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
5752

58-
if not os.path.exists(test_data_file_path):
53+
if not test_data_file_path.exists():
5954
return None
6055

61-
with open(test_data_file_path) as fh:
56+
with test_data_file_path.open("r") as fh:
6257
return fh.read()

0 commit comments

Comments
 (0)