26
26
27
27
import builtins
28
28
import inspect
29
- import re
30
29
from collections .abc import Iterator
31
30
from dataclasses import (
32
31
dataclass ,
33
32
field ,
34
33
)
35
34
36
- import betterproto2
37
35
from betterproto2 import unwrap
38
36
39
37
from betterproto2_compiler .compile .importing import get_type_reference , parse_source_type_name
43
41
pythonize_field_name ,
44
42
pythonize_method_name ,
45
43
)
46
- from betterproto2_compiler .known_types import KNOWN_METHODS
44
+ from betterproto2_compiler .known_types import KNOWN_METHODS , WRAPPED_TYPES
47
45
from betterproto2_compiler .lib .google .protobuf import (
48
46
DescriptorProto ,
49
47
EnumDescriptorProto ,
@@ -318,8 +316,15 @@ def get_field_string(self) -> str:
318
316
@property
319
317
def betterproto_field_args (self ) -> list [str ]:
320
318
args = []
321
- if self .field_wraps :
322
- args .append (f"wraps={ self .field_wraps } " )
319
+
320
+ if self .proto_obj .type == FieldDescriptorProtoType .TYPE_MESSAGE :
321
+ type_package , type_name = parse_source_type_name (self .proto_obj .type_name , self .output_file .parent_request )
322
+
323
+ if (type_package , type_name ) in WRAPPED_TYPES :
324
+ # Without the lambda function, the type is evaluated right away, which fails since the corresponding
325
+ # import is placed at the end of the file to avoid circular imports.
326
+ args .append (f"wrap=lambda: { self .py_type } " )
327
+
323
328
if self .optional :
324
329
args .append ("optional=True" )
325
330
elif self .repeated :
@@ -338,16 +343,6 @@ def use_builtins(self) -> bool:
338
343
self .py_type == self .py_name and self .py_name in dir (builtins )
339
344
)
340
345
341
- @property
342
- def field_wraps (self ) -> str | None :
343
- """Returns betterproto wrapped field type or None."""
344
- match_wrapper = re .match (r"\.google\.protobuf\.(.+)Value$" , self .proto_obj .type_name )
345
- if match_wrapper :
346
- wrapped_type = "TYPE_" + match_wrapper .group (1 ).upper ()
347
- if hasattr (betterproto2 , wrapped_type ):
348
- return f"betterproto2.{ wrapped_type } "
349
- return None
350
-
351
346
@property
352
347
def repeated (self ) -> bool :
353
348
return self .proto_obj .label == FieldDescriptorProtoLabel .LABEL_REPEATED
@@ -405,6 +400,14 @@ def py_type(self) -> str:
405
400
@property
406
401
def annotation (self ) -> str :
407
402
py_type = self .py_type
403
+
404
+ # Replace by the wrapping type if needed
405
+ if self .proto_obj .type == FieldDescriptorProtoType .TYPE_MESSAGE :
406
+ type_package , type_name = parse_source_type_name (self .proto_obj .type_name , self .output_file .parent_request )
407
+
408
+ if wrapped_type := WRAPPED_TYPES .get ((type_package , type_name )):
409
+ py_type = wrapped_type
410
+
408
411
if self .use_builtins :
409
412
py_type = f"builtins.{ py_type } "
410
413
if self .repeated :
0 commit comments