Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions src/betterproto2_compiler/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
FieldDescriptorProto,
FieldDescriptorProtoLabel,
FieldDescriptorProtoType,
FieldDescriptorProtoType as FieldType,
FileDescriptorProto,
MethodDescriptorProto,
)
Expand Down Expand Up @@ -339,7 +340,7 @@ def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProt
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
if message_type == map_entry:
for nested in parent_message.nested_type: # parent message
if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry:
if nested.name.replace("_", "").lower() == map_entry and nested.options and nested.options.map_entry:
return True
return False

Expand Down Expand Up @@ -382,7 +383,10 @@ def get_field_string(self) -> str:
"""Construct string representation of this field as a field."""
name = f"{self.py_name}"
field_args = ", ".join(([""] + self.betterproto_field_args) if self.betterproto_field_args else [])
betterproto_field_type = f"betterproto2.{self.field_type}_field({self.proto_obj.number}{field_args})"

betterproto_field_type = (
f"betterproto2.field({self.proto_obj.number}, betterproto2.{str(self.field_type)}{field_args})"
)
if self.py_name in dir(builtins):
self.parent.builtins_types.add(self.py_name)
return f'{name}: "{self.annotation}" = {betterproto_field_type}'
Expand All @@ -396,9 +400,9 @@ def betterproto_field_args(self) -> list[str]:
args.append("optional=True")
if self.repeated:
args.append("repeated=True")
if self.field_type == "enum":
if self.field_type == FieldType.TYPE_ENUM:
t = self.py_type
args.append(f"enum_default_value=lambda: {t}.try_value(0)")
args.append(f"default_factory=lambda: {t}.try_value(0)")
return args

@property
Expand Down Expand Up @@ -426,12 +430,13 @@ def repeated(self) -> bool:

@property
def optional(self) -> bool:
return self.proto_obj.proto3_optional or (self.field_type == "message" and not self.repeated)
# TODO not for maps
return self.proto_obj.proto3_optional or (self.field_type == FieldType.TYPE_MESSAGE and not self.repeated)

@property
def field_type(self) -> str:
"""String representation of proto field type."""
return FieldDescriptorProtoType(self.proto_obj.type).name.lower().replace("type_", "")
def field_type(self) -> FieldType:
# TODO it should be possible to remove constructor
return FieldType(self.proto_obj.type)

@property
def packed(self) -> bool:
Expand Down Expand Up @@ -540,13 +545,17 @@ def ready(self) -> None:

raise ValueError("can't find enum")

@property
def betterproto_field_args(self) -> list[str]:
return [f"betterproto2.{self.proto_k_type}", f"betterproto2.{self.proto_v_type}"]

@property
def field_type(self) -> str:
return "map"
def get_field_string(self) -> str:
"""Construct string representation of this field as a field."""
betterproto_field_type = (
f"betterproto2.field({self.proto_obj.number}, "
"betterproto2.TYPE_MAP, "
f"map_types=(betterproto2.{self.proto_k_type}, "
f"betterproto2.{self.proto_v_type}))"
)
if self.py_name in dir(builtins):
self.parent.builtins_types.add(self.py_name)
return f'{self.py_name}: "{self.annotation}" = {betterproto_field_type}'

@property
def annotation(self) -> str:
Expand Down
107 changes: 107 additions & 0 deletions tests/inputs/features/features.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
syntax = "proto3";

import "google/protobuf/timestamp.proto";
import "google/protobuf/duration.proto";
import "google/protobuf/wrappers.proto";

package features;

message Bar {
string name = 1;
}

message Foo {
string name = 1;
Bar child = 2;
}

enum Enum {
ZERO = 0;
ONE = 1;
}

message EnumMsg {
Enum enum = 1;
}

message Newer {
bool x = 1;
int32 y = 2;
string z = 3;
}

message Older {
bool x = 1;
}

message IntMsg {
int32 val = 1;
}

message OneofMsg {
oneof group1 {
int32 x = 1;
string y = 2;
}
oneof group2 {
IntMsg a = 3;
string b = 4;
}
}

message JsonCasingMsg {
int32 pascal_case = 1;
int32 camel_case = 2;
int32 snake_case = 3;
int32 kabob_case = 4;
}

message OptionalBoolMsg {
google.protobuf.BoolValue field = 1;
}

message OptionalDatetimeMsg {
google.protobuf.Timestamp field = 1;
}

message Empty {}

message TimeMsg {
google.protobuf.Timestamp timestamp = 1;
google.protobuf.Duration duration = 2;
}

message MsgA {
int32 some_int = 1;
double some_double = 2;
string some_str = 3;
bool some_bool = 4;
}

message MsgB {
int32 some_int = 1;
double some_double = 2;
string some_str = 3;
bool some_bool = 4;
int32 some_default_int = 5;
double some_default_double = 6;
string some_default_str = 7;
bool some_default_bool = 8;
}

message MsgC {
oneof group1 {
int32 int_field = 1;
string string_field = 2;
Empty empty_field = 3;
}
}

message MsgD {
repeated google.protobuf.Timestamp timestamps = 1;
}

message MsgE {
bool bool_field = 1;
optional int32 int_field = 2;
}
43 changes: 43 additions & 0 deletions tests/inputs/pickling/pickling.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
syntax = "proto3";

package pickling;

import "google/protobuf/any.proto";
import "google/protobuf/struct.proto";


message Test {}

message Fe {
string abc = 1;
}

message Fi {
string abc = 1;
}

message Fo {
string abc = 1;
}

message NestedData {
map<string, google.protobuf.Struct> struct_foo = 1;
map<string, google.protobuf.Any> map_str_any_bar = 2;
}

message Complex {
string foo_str = 1;
oneof grp {
Fe fe = 3;
Fi fi = 4;
Fo fo = 5;
}
NestedData nested_data = 6;
map<string, google.protobuf.Any> mapping = 7;
}

message PickledMessage {
bool foo = 1;
int32 bar = 2;
repeated string baz = 3;
}
7 changes: 7 additions & 0 deletions tests/inputs/stream_stream/stream_stream.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
syntax = "proto3";

package stream_stream;

message Message {
string body = 1;
}
Loading