Skip to content

Commit 3f519d4

Browse files
Fixes #23 again, a broken test made it seem the issue was fixed before.
1 parent dedead0 commit 3f519d4

File tree

6 files changed

+129
-74
lines changed

6 files changed

+129
-74
lines changed

CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
> `2.0.0` will be released once the interface is stable.
1616
1717
- Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
18-
- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
19-
- Fix No arguments are generated for stub methods when using import with proto definition
18+
- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
2019
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
2120

2221
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)

betterproto/plugin.py

Lines changed: 82 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
1212

1313
import betterproto
14-
from betterproto.compile.importing import get_type_reference
14+
from betterproto.compile.importing import get_type_reference, parse_source_type_name
1515
from betterproto.compile.naming import (
1616
pythonize_class_name,
1717
pythonize_field_name,
1818
pythonize_method_name,
1919
)
20+
from betterproto.lib.google.protobuf import ServiceDescriptorProto
2021

2122
try:
2223
# betterproto[compiler] specific dependencies
@@ -76,11 +77,12 @@ def get_py_zero(type_num: int) -> Union[str, float]:
7677
return zero
7778

7879

79-
# Todo: Keep information about nested hierarchy
8080
def traverse(proto_file):
81+
# Todo: Keep information about nested hierarchy
8182
def _traverse(path, items, prefix=""):
8283
for i, item in enumerate(items):
83-
# Adjust the name since we flatten the heirarchy.
84+
# Adjust the name since we flatten the hierarchy.
85+
# Todo: don't change the name, but include full name in returned tuple
8486
item.name = next_prefix = prefix + item.name
8587
yield item, path + [i]
8688

@@ -162,17 +164,21 @@ def generate_code(request, response):
162164
output_package_content["template_data"] = template_data
163165

164166
# Read Messages and Enums
167+
output_types = []
165168
for output_package_name, output_package_content in output_package_files.items():
166169
for proto_file in output_package_content["files"]:
167170
for item, path in traverse(proto_file):
168-
read_protobuf_object(item, path, proto_file, output_package_content)
171+
type_data = read_protobuf_type(
172+
item, path, proto_file, output_package_content
173+
)
174+
output_types.append(type_data)
169175

170176
# Read Services
171177
for output_package_name, output_package_content in output_package_files.items():
172178
for proto_file in output_package_content["files"]:
173179
for index, service in enumerate(proto_file.service):
174180
read_protobuf_service(
175-
service, index, proto_file, output_package_content
181+
service, index, proto_file, output_package_content, output_types
176182
)
177183

178184
# Render files
@@ -214,63 +220,31 @@ def generate_code(request, response):
214220
print(f"Writing {output_package_name}", file=sys.stderr)
215221

216222

217-
def read_protobuf_service(service: DescriptorProto, index, proto_file, content):
218-
input_package_name = content["input_package"]
219-
template_data = content["template_data"]
220-
# print(service, file=sys.stderr)
221-
data = {
222-
"name": service.name,
223-
"py_name": pythonize_class_name(service.name),
224-
"comment": get_comment(proto_file, [6, index]),
225-
"methods": [],
226-
}
227-
for j, method in enumerate(service.method):
228-
input_message = None
229-
input_type = get_type_reference(
230-
input_package_name, template_data["imports"], method.input_type
231-
).strip('"')
232-
for msg in template_data["messages"]:
233-
if msg["name"] == input_type:
234-
input_message = msg
235-
for field in msg["properties"]:
236-
if field["zero"] == "None":
237-
template_data["typing_imports"].add("Optional")
238-
break
223+
def lookup_method_input_type(method, types):
224+
package, name = parse_source_type_name(method.input_type)
239225

240-
data["methods"].append(
241-
{
242-
"name": method.name,
243-
"py_name": pythonize_method_name(method.name),
244-
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
245-
"route": f"/{input_package_name}.{service.name}/{method.name}",
246-
"input": get_type_reference(
247-
input_package_name, template_data["imports"], method.input_type
248-
).strip('"'),
249-
"input_message": input_message,
250-
"output": get_type_reference(
251-
input_package_name,
252-
template_data["imports"],
253-
method.output_type,
254-
unwrap=False,
255-
),
256-
"client_streaming": method.client_streaming,
257-
"server_streaming": method.server_streaming,
258-
}
259-
)
226+
for known_type in types:
227+
if known_type["type"] != "Message":
228+
continue
260229

261-
if method.client_streaming:
262-
template_data["typing_imports"].add("AsyncIterable")
263-
template_data["typing_imports"].add("Iterable")
264-
template_data["typing_imports"].add("Union")
265-
if method.server_streaming:
266-
template_data["typing_imports"].add("AsyncIterator")
267-
template_data["services"].append(data)
230+
# Nested types are currently flattened without dots.
231+
# Todo: keep a fully quantified name in types, that is comparable with method.input_type
232+
if (
233+
package == known_type["package"]
234+
and name.replace(".", "") == known_type["name"]
235+
):
236+
return known_type
268237

269238

270-
def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content):
239+
def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
271240
input_package_name = content["input_package"]
272241
template_data = content["template_data"]
273-
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
242+
data = {
243+
"name": item.name,
244+
"py_name": pythonize_class_name(item.name),
245+
"descriptor": item,
246+
"package": input_package_name,
247+
}
274248
if isinstance(item, DescriptorProto):
275249
# print(item, file=sys.stderr)
276250
if item.options.map_entry:
@@ -373,6 +347,7 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
373347
# print(f, file=sys.stderr)
374348

375349
template_data["messages"].append(data)
350+
return data
376351
elif isinstance(item, EnumDescriptorProto):
377352
# print(item.name, path, file=sys.stderr)
378353
data.update(
@@ -391,6 +366,57 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
391366
)
392367

393368
template_data["enums"].append(data)
369+
return data
370+
371+
372+
def read_protobuf_service(
373+
service: ServiceDescriptorProto, index, proto_file, content, output_types
374+
):
375+
input_package_name = content["input_package"]
376+
template_data = content["template_data"]
377+
# print(service, file=sys.stderr)
378+
data = {
379+
"name": service.name,
380+
"py_name": pythonize_class_name(service.name),
381+
"comment": get_comment(proto_file, [6, index]),
382+
"methods": [],
383+
}
384+
for j, method in enumerate(service.method):
385+
method_input_message = lookup_method_input_type(method, output_types)
386+
387+
if method_input_message:
388+
for field in method_input_message["properties"]:
389+
if field["zero"] == "None":
390+
template_data["typing_imports"].add("Optional")
391+
392+
data["methods"].append(
393+
{
394+
"name": method.name,
395+
"py_name": pythonize_method_name(method.name),
396+
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
397+
"route": f"/{input_package_name}.{service.name}/{method.name}",
398+
"input": get_type_reference(
399+
input_package_name, template_data["imports"], method.input_type
400+
).strip('"'),
401+
"input_message": method_input_message,
402+
"output": get_type_reference(
403+
input_package_name,
404+
template_data["imports"],
405+
method.output_type,
406+
unwrap=False,
407+
),
408+
"client_streaming": method.client_streaming,
409+
"server_streaming": method.server_streaming,
410+
}
411+
)
412+
413+
if method.client_streaming:
414+
template_data["typing_imports"].add("AsyncIterable")
415+
template_data["typing_imports"].add("Iterable")
416+
template_data["typing_imports"].add("Union")
417+
if method.server_streaming:
418+
template_data["typing_imports"].add("AsyncIterator")
419+
template_data["services"].append(data)
394420

395421

396422
def main():
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
syntax = "proto3";
2+
3+
package child;
4+
5+
message ChildRequestMessage {
6+
int32 child_argument = 1;
7+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
syntax = "proto3";
22

33
import "request_message.proto";
4+
import "child_package_request_message.proto";
45

56
// Tests generated service correctly imports the RequestMessage
67

78
service Test {
89
rpc DoThing (RequestMessage) returns (RequestResponse);
10+
rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse);
11+
rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse);
912
}
1013

1114

1215
message RequestResponse {
1316
int32 value = 1;
1417
}
1518

19+
message Nested {
20+
message RequestMessage {
21+
int32 nestedArgument = 1;
22+
}
23+
}

betterproto/tests/inputs/import_service_input_message/test_import_service.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
3+
from betterproto.tests.mocks import MockChannel
4+
from betterproto.tests.output_betterproto.import_service_input_message import (
5+
RequestResponse,
6+
TestStub,
7+
)
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_service_correctly_imports_reference_message():
12+
mock_response = RequestResponse(value=10)
13+
service = TestStub(MockChannel([mock_response]))
14+
response = await service.do_thing(argument=1)
15+
assert mock_response == response
16+
17+
18+
@pytest.mark.asyncio
19+
async def test_service_correctly_imports_reference_message_from_child_package():
20+
mock_response = RequestResponse(value=10)
21+
service = TestStub(MockChannel([mock_response]))
22+
response = await service.do_thing2(child_argument=1)
23+
assert mock_response == response
24+
25+
26+
@pytest.mark.asyncio
27+
async def test_service_correctly_imports_nested_reference():
28+
mock_response = RequestResponse(value=10)
29+
service = TestStub(MockChannel([mock_response]))
30+
response = await service.do_thing3(nested_argument=1)
31+
assert mock_response == response

0 commit comments

Comments
 (0)