|
1 | 1 | #!/usr/bin/env python
|
2 | 2 | import os
|
| 3 | +import sys |
| 4 | +from typing import Set |
| 5 | + |
| 6 | +from betterproto.tests.util import get_directories, inputs_path, output_path_betterproto, output_path_reference, \ |
| 7 | + protoc_plugin, protoc_reference |
3 | 8 |
|
4 | 9 | # Force pure-python implementation instead of C++, otherwise imports
|
5 | 10 | # break things because we can't properly reset the symbol database.
|
6 | 11 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
7 | 12 |
|
8 |
| -import importlib |
9 |
| -import json |
10 |
| -import subprocess |
11 |
| -import sys |
12 |
| -from typing import Generator, Tuple |
13 | 13 |
|
14 |
| -from google.protobuf import symbol_database |
15 |
| -from google.protobuf.descriptor_pool import DescriptorPool |
16 |
| -from google.protobuf.json_format import MessageToJson, Parse |
| 14 | +def generate(whitelist: Set[str]): |
| 15 | + path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)} |
| 16 | + name_whitelist = {e for e in whitelist if not os.path.exists(e)} |
| 17 | + |
| 18 | + test_case_names = set(get_directories(inputs_path)) |
| 19 | + |
| 20 | + for test_case_name in sorted(test_case_names): |
| 21 | + test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name)) |
| 22 | + |
| 23 | + if whitelist and test_case_path not in path_whitelist and test_case_name not in name_whitelist: |
| 24 | + continue |
17 | 25 |
|
| 26 | + case_output_dir_reference = os.path.join(output_path_reference, test_case_name) |
| 27 | + case_output_dir_betterproto = os.path.join(output_path_betterproto, test_case_name) |
18 | 28 |
|
19 |
| -root = os.path.dirname(os.path.realpath(__file__)) |
| 29 | + print(f'Generating output for {test_case_name}') |
| 30 | + os.makedirs(case_output_dir_reference, exist_ok=True) |
| 31 | + os.makedirs(case_output_dir_betterproto, exist_ok=True) |
20 | 32 |
|
| 33 | + protoc_reference(test_case_path, case_output_dir_reference) |
| 34 | + protoc_plugin(test_case_path, case_output_dir_betterproto) |
21 | 35 |
|
22 |
| -def get_files(end: str) -> Generator[str, None, None]: |
23 |
| - for r, dirs, files in os.walk(root): |
24 |
| - for filename in [f for f in files if f.endswith(end)]: |
25 |
| - yield os.path.join(r, filename) |
26 | 36 |
|
| 37 | +HELP = "\n".join([ |
| 38 | + 'Usage: python generate.py', |
| 39 | + ' python generate.py [DIRECTORIES or NAMES]', |
| 40 | + 'Generate python classes for standard tests.', |
| 41 | + '', |
| 42 | + 'DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.', |
| 43 | + ' python generate.py inputs/bool inputs/double inputs/enum', |
| 44 | + '', |
| 45 | + 'NAMES One or more test-case names to generate classes for.', |
| 46 | + ' python generate.py bool double enums' |
| 47 | +]) |
27 | 48 |
|
28 |
| -def get_base(filename: str) -> str: |
29 |
| - return os.path.splitext(os.path.basename(filename))[0] |
30 | 49 |
|
| 50 | +def main(): |
| 51 | + if set(sys.argv).intersection({'-h', '--help'}): |
| 52 | + print(HELP) |
| 53 | + return |
| 54 | + whitelist = set(sys.argv[1:]) |
31 | 55 |
|
32 |
| -def ensure_ext(filename: str, ext: str) -> str: |
33 |
| - if not filename.endswith(ext): |
34 |
| - return filename + ext |
35 |
| - return filename |
| 56 | + generate(whitelist) |
36 | 57 |
|
37 | 58 |
|
38 | 59 | if __name__ == "__main__":
|
39 |
| - os.chdir(root) |
40 |
| - |
41 |
| - if len(sys.argv) > 1: |
42 |
| - proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]] |
43 |
| - bases = {get_base(f) for f in proto_files} |
44 |
| - json_files = [ |
45 |
| - f for f in get_files(".json") if get_base(f).split("-")[0] in bases |
46 |
| - ] |
47 |
| - else: |
48 |
| - proto_files = get_files(".proto") |
49 |
| - json_files = get_files(".json") |
50 |
| - |
51 |
| - for filename in proto_files: |
52 |
| - print(f"Generating code for {os.path.basename(filename)}") |
53 |
| - subprocess.run( |
54 |
| - f"protoc --python_out=. {os.path.basename(filename)}", shell=True |
55 |
| - ) |
56 |
| - subprocess.run( |
57 |
| - f"protoc --plugin=protoc-gen-custom=../plugin.py --custom_out=. {os.path.basename(filename)}", |
58 |
| - shell=True, |
59 |
| - ) |
60 |
| - |
61 |
| - for filename in json_files: |
62 |
| - # Reset the internal symbol database so we can import the `Test` message |
63 |
| - # multiple times. Ugh. |
64 |
| - sym = symbol_database.Default() |
65 |
| - sym.pool = DescriptorPool() |
66 |
| - |
67 |
| - parts = get_base(filename).split("-") |
68 |
| - out = filename.replace(".json", ".bin") |
69 |
| - print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}") |
70 |
| - |
71 |
| - imported = importlib.import_module(f"{parts[0]}_pb2") |
72 |
| - input_json = open(filename).read() |
73 |
| - parsed = Parse(input_json, imported.Test()) |
74 |
| - serialized = parsed.SerializeToString() |
75 |
| - preserve = "casing" not in filename |
76 |
| - serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve) |
77 |
| - |
78 |
| - s_loaded = json.loads(serialized_json) |
79 |
| - in_loaded = json.loads(input_json) |
80 |
| - |
81 |
| - if s_loaded != in_loaded: |
82 |
| - raise AssertionError("Expected JSON to be equal:", s_loaded, in_loaded) |
83 |
| - |
84 |
| - open(out, "wb").write(serialized) |
| 60 | + main() |
0 commit comments