Skip to content

Commit 1002ae3

Browse files
rwgkcopybara-github
authored andcommitted
pybind11_protobuf:
`LOG(WARNING)` `FALL BACK TO PROTOBUF SERIALIZE/PARSE` based on `ExtensionsWithUnknownFieldsPolicy`. Low-level change in preparation for PyCLIF-pybind11 rollout. PiperOrigin-RevId: 604380353
1 parent 3b11990 commit 1002ae3

File tree

7 files changed

+51
-24
lines changed

7 files changed

+51
-24
lines changed

pybind11_protobuf/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,6 @@ cc_library(
8787
name = "check_unknown_fields",
8888
srcs = ["check_unknown_fields.cc"],
8989
hdrs = ["check_unknown_fields.h"],
90-
visibility = [
91-
"//visibility:private",
92-
],
9390
deps = [
9491
"@com_google_absl//absl/container:flat_hash_map",
9592
"@com_google_absl//absl/container:flat_hash_set",

pybind11_protobuf/check_unknown_fields.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
183183

184184
std::optional<std::string> CheckRecursively(
185185
const ::google::protobuf::python::PyProto_API* py_proto_api,
186-
const ::google::protobuf::Message* message, bool build_error_message_if_any) {
186+
const ::google::protobuf::Message* message) {
187187
const auto* root_descriptor = message->GetDescriptor();
188188
HasUnknownFields search{py_proto_api, root_descriptor};
189189
if (!search.FindUnknownFieldsRecursive(message, 0u)) {
@@ -193,9 +193,6 @@ std::optional<std::string> CheckRecursively(
193193
search.FieldFQN())) != 0) {
194194
return std::nullopt;
195195
}
196-
if (!build_error_message_if_any) {
197-
return ""; // This indicates that an unknown field was found.
198-
}
199196
return search.BuildErrorMessage();
200197
}
201198

pybind11_protobuf/check_unknown_fields.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
4747

4848
std::optional<std::string> CheckRecursively(
4949
const ::google::protobuf::python::PyProto_API* py_proto_api,
50-
const ::google::protobuf::Message* top_message, bool build_error_message_if_any);
50+
const ::google::protobuf::Message* top_message);
5151

5252
} // namespace pybind11_protobuf::check_unknown_fields
5353

pybind11_protobuf/proto_cast_util.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <iostream>
1010
#include <memory>
1111
#include <string>
12+
#include <unordered_set>
1213
#include <utility>
1314
#include <vector>
1415

@@ -828,14 +829,18 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy,
828829

829830
std::optional<std::string> unknown_field_message =
830831
check_unknown_fields::CheckRecursively(
831-
GlobalState::instance()->py_proto_api(), src,
832-
check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
833-
UnknownFieldsAreDisallowed());
832+
GlobalState::instance()->py_proto_api(), src);
834833
if (unknown_field_message) {
835-
if (!unknown_field_message->empty()) {
834+
if (check_unknown_fields::ExtensionsWithUnknownFieldsPolicy::
835+
UnknownFieldsAreDisallowed()) {
836836
throw py::value_error(*unknown_field_message);
837837
}
838-
// Fall back to serialize/parse.
838+
// Emit one LOG(WARNING) per unique unknown_field_message:
839+
static auto fall_back_log_shown = new std::unordered_set<std::string>();
840+
if (fall_back_log_shown->insert(*unknown_field_message).second) {
841+
LOG(WARNING) << "FALL BACK TO PROTOBUF SERIALIZE/PARSE: "
842+
<< *unknown_field_message;
843+
}
839844
return GenericPyProtoCast(src, policy, parent, is_const);
840845
}
841846

pybind11_protobuf/tests/BUILD

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ pybind_extension(
161161
],
162162
)
163163

164+
EXTENSION_TEST_DEPS_COMMON = [
165+
":extension_in_other_file_in_deps_py_pb2",
166+
":extension_in_other_file_py_pb2",
167+
":extension_nest_repeated_py_pb2",
168+
":extension_py_pb2",
169+
":test_py_pb2", # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811
170+
"@com_google_absl_py//absl/testing:absltest",
171+
"@com_google_absl_py//absl/testing:parameterized",
172+
]
173+
164174
py_test(
165175
name = "extension_test",
166176
srcs = ["extension_test.py"],
@@ -170,14 +180,20 @@ py_test(
170180
],
171181
python_version = "PY3",
172182
srcs_version = "PY3",
173-
deps = [
174-
":extension_in_other_file_in_deps_py_pb2",
175-
":extension_in_other_file_py_pb2",
176-
":extension_nest_repeated_py_pb2",
177-
":extension_py_pb2",
178-
":test_py_pb2", # fixdeps: keep - Direct dependency needed in open-source version, see https://github.com/grpc/grpc/issues/22811
179-
"@com_google_absl_py//absl/testing:absltest",
180-
"@com_google_absl_py//absl/testing:parameterized",
183+
deps = EXTENSION_TEST_DEPS_COMMON + ["@com_google_protobuf//:protobuf_python"],
184+
)
185+
186+
py_test(
187+
name = "extension_disallow_unknown_fields_test",
188+
srcs = ["extension_test.py"],
189+
data = [
190+
":extension_module.so",
191+
":proto_enum_module.so",
192+
],
193+
main = "extension_test.py",
194+
python_version = "PY3",
195+
srcs_version = "PY3",
196+
deps = EXTENSION_TEST_DEPS_COMMON + [
181197
"@com_google_protobuf//:protobuf_python",
182198
],
183199
)

pybind11_protobuf/tests/extension_module.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ void DefReserialize(py::module_& m, const char* py_name) {
4848
PYBIND11_MODULE(extension_module, m) {
4949
pybind11_protobuf::ImportNativeProtoCasters();
5050

51+
m.def("extensions_with_unknown_fields_are_disallowed", []() {
52+
return pybind11_protobuf::check_unknown_fields::
53+
ExtensionsWithUnknownFieldsPolicy::UnknownFieldsAreDisallowed();
54+
});
55+
5156
m.def("get_base_message", []() -> BaseMessage { return {}; });
5257

5358
m.def(

pybind11_protobuf/tests/extension_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
from pybind11_protobuf.tests import extension_pb2
2121

2222

23+
def unknown_field_exception_is_expected():
24+
return (
25+
api_implementation.Type() == 'cpp'
26+
and m.extensions_with_unknown_fields_are_disallowed()
27+
)
28+
29+
2330
def get_py_message(value=5,
2431
in_other_file_in_deps_value=None,
2532
in_other_file_value=None):
@@ -103,7 +110,7 @@ def test_extension_in_other_file_roundtrip(self):
103110

104111
def test_reserialize_base_message(self):
105112
a = get_py_message(in_other_file_value=63)
106-
if api_implementation.Type() == 'cpp':
113+
if unknown_field_exception_is_expected():
107114
with self.assertRaises(ValueError) as ctx:
108115
m.reserialize_base_message(a)
109116
self.assertStartsWith(
@@ -127,7 +134,7 @@ def test_reserialize_nest_level2(self):
127134
a = extension_pb2.NestLevel2(
128135
nest_lvl1=extension_pb2.NestLevel1(
129136
base_msg=get_py_message(in_other_file_value=52)))
130-
if api_implementation.Type() == 'cpp':
137+
if unknown_field_exception_is_expected():
131138
with self.assertRaises(ValueError) as ctx:
132139
m.reserialize_nest_level2(a)
133140
self.assertStartsWith(
@@ -154,7 +161,7 @@ def test_reserialize_nest_repeated(self):
154161
get_py_message(in_other_file_value=74),
155162
get_py_message(in_other_file_value=85)
156163
])
157-
if api_implementation.Type() == 'cpp':
164+
if unknown_field_exception_is_expected():
158165
with self.assertRaises(ValueError) as ctx:
159166
m.reserialize_nest_repeated(a)
160167
self.assertStartsWith(

0 commit comments

Comments
 (0)