Skip to content

Commit 130acff

Browse files
committed
Generate __init__.py files
1 parent dcb7102 commit 130acff

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

Pipfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ jinja2 = "*"
1818
python_version = "3.7"
1919

2020
[scripts]
21-
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=."
21+
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output"
2222
generate = "python betterproto/tests/generate.py"
2323
test = "pytest ./betterproto/tests"

betterproto/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ def string_field(number: int, default: str = "") -> Any:
209209
return dataclass_field(number, TYPE_STRING, default=default)
210210

211211

212+
def bytes_field(number: int, default: bytes = b"") -> Any:
213+
return dataclass_field(number, TYPE_BYTES, default=default)
214+
215+
212216
def message_field(number: int) -> Any:
213217
return dataclass_field(number, TYPE_MESSAGE)
214218

protoc-gen-betterpy.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ def py_type(
4040
message_type = descriptor.type_name.lstrip(".")
4141
if message_type.startswith(package):
4242
# This is the current package, which has nested types flattened.
43-
message_type = message_type.lstrip(package).lstrip(".").replace(".", "")
43+
message_type = (
44+
f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
45+
)
4446

4547
if "." in message_type:
4648
# This is imported from another package. No need
4749
# to use a forward ref and we need to add the import.
48-
message_type = message_type.strip('"')
4950
parts = message_type.split(".")
5051
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
5152
message_type = f"{parts[-2]}.{parts[-1]}"
@@ -58,7 +59,7 @@ def py_type(
5859
# file=sys.stderr,
5960
# )
6061

61-
return f'"{message_type}"'
62+
return message_type
6263
elif descriptor.type == 12:
6364
return "bytes"
6465
else:
@@ -247,10 +248,28 @@ def generate_code(request, response):
247248
# Fill response
248249
f = response.file.add()
249250
# print(filename, file=sys.stderr)
250-
f.name = filename + ".py"
251+
f.name = filename.replace(".", os.path.sep) + ".py"
252+
251253
# f.content = json.dumps(output, indent=2)
252254
f.content = template.render(description=output).rstrip("\n") + "\n"
253255

256+
inits = set([""])
257+
for f in response.file:
258+
# Ensure output paths exist
259+
print(f.name, file=sys.stderr)
260+
dirnames = os.path.dirname(f.name)
261+
if dirnames:
262+
os.makedirs(dirnames, exist_ok=True)
263+
base = ""
264+
for part in dirnames.split(os.path.sep):
265+
base = os.path.join(base, part)
266+
inits.add(base)
267+
268+
for base in inits:
269+
init = response.file.add()
270+
init.name = os.path.join(base, "__init__.py")
271+
init.content = b""
272+
254273

255274
if __name__ == "__main__":
256275
# Read request message from stdin

0 commit comments

Comments
 (0)