Skip to content

Commit 2f62189

Browse files
authored
Fix typing and datetime imports not being present for service method type annotations (#183)
1 parent 8a21536 commit 2f62189

File tree

4 files changed

+74
-14
lines changed

4 files changed

+74
-14
lines changed

src/betterproto/plugin/models.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -324,17 +324,7 @@ def __post_init__(self) -> None:
324324
# Add field to message
325325
self.parent.fields.append(self)
326326
# Check for new imports
327-
annotation = self.annotation
328-
if "Optional[" in annotation:
329-
self.output_file.typing_imports.add("Optional")
330-
if "List[" in annotation:
331-
self.output_file.typing_imports.add("List")
332-
if "Dict[" in annotation:
333-
self.output_file.typing_imports.add("Dict")
334-
if "timedelta" in annotation:
335-
self.output_file.datetime_imports.add("timedelta")
336-
if "datetime" in annotation:
337-
self.output_file.datetime_imports.add("datetime")
327+
self.add_imports_to(self.output_file)
338328
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
339329

340330
def get_field_string(self, indent: int = 4) -> str:
@@ -356,6 +346,33 @@ def betterproto_field_args(self) -> List[str]:
356346
args.append(f"wraps={self.field_wraps}")
357347
return args
358348

349+
@property
350+
def datetime_imports(self) -> Set[str]:
351+
imports = set()
352+
annotation = self.annotation
353+
# FIXME: false positives - e.g. `MyDatetimedelta`
354+
if "timedelta" in annotation:
355+
imports.add("timedelta")
356+
if "datetime" in annotation:
357+
imports.add("datetime")
358+
return imports
359+
360+
@property
361+
def typing_imports(self) -> Set[str]:
362+
imports = set()
363+
annotation = self.annotation
364+
if "Optional[" in annotation:
365+
imports.add("Optional")
366+
if "List[" in annotation:
367+
imports.add("List")
368+
if "Dict[" in annotation:
369+
imports.add("Dict")
370+
return imports
371+
372+
def add_imports_to(self, output_file: OutputTemplate) -> None:
373+
output_file.datetime_imports.update(self.datetime_imports)
374+
output_file.typing_imports.update(self.typing_imports)
375+
359376
@property
360377
def field_wraps(self) -> Optional[str]:
361378
"""Returns betterproto wrapped field type or None."""
@@ -577,11 +594,10 @@ def __post_init__(self) -> None:
577594
# Add method to service
578595
self.parent.methods.append(self)
579596

580-
# Check for Optional import
597+
# Check for imports
581598
if self.py_input_message:
582599
for f in self.py_input_message.fields:
583-
if f.default_value_string == "None":
584-
self.output_file.typing_imports.add("Optional")
600+
f.add_imports_to(self.output_file)
585601
if "Optional" in self.py_output_message_type:
586602
self.output_file.typing_imports.add("Optional")
587603
self.mutable_default_args # ensure this is called before rendering

tests/inputs/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"googletypes_response",
1515
"googletypes_response_embedded",
1616
"service",
17+
"service_separate_packages",
1718
"import_service_input_message",
1819
"googletypes_service_returns_empty",
1920
"googletypes_service_returns_googletype",
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
syntax = "proto3";
2+
3+
import "google/protobuf/duration.proto";
4+
import "google/protobuf/timestamp.proto";
5+
6+
package things.messages;
7+
8+
message DoThingRequest {
9+
string name = 1;
10+
11+
// use `repeated` so we can check if `List` is correctly imported
12+
repeated string comments = 2;
13+
14+
// use google types `timestamp` and `duration` so we can check
15+
// if everything from `datetime` is correctly imported
16+
google.protobuf.Timestamp when = 3;
17+
google.protobuf.Duration duration = 4;
18+
}
19+
20+
message DoThingResponse {
21+
repeated string names = 1;
22+
}
23+
24+
message GetThingRequest {
25+
string name = 1;
26+
}
27+
28+
message GetThingResponse {
29+
string name = 1;
30+
int32 version = 2;
31+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
syntax = "proto3";
2+
3+
import "messages.proto";
4+
5+
package things.service;
6+
7+
service Test {
8+
rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
9+
rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
10+
rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
11+
rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
12+
}

0 commit comments

Comments
 (0)