Skip to content

Commit 3f77b28

Browse files
rwgkcopybara-github
authored andcommitted
Fix (also replace - with _) and rename PythonPackageForDescriptor().
Add tests covering the situations in which the problem was observed. PiperOrigin-RevId: 623374809
1 parent b4a2e87 commit 3f77b28

File tree

8 files changed

+200
-9
lines changed

8 files changed

+200
-9
lines changed

pybind11_protobuf/proto_cast_util.cc

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
#include "absl/log/check.h"
1919
#include "absl/log/log.h"
2020
#include "absl/memory/memory.h"
21+
#include "absl/strings/match.h"
22+
#include "absl/strings/str_cat.h"
2123
#include "absl/strings/str_replace.h"
2224
#include "absl/strings/string_view.h"
25+
#include "absl/strings/strip.h"
2326
#include "absl/types/optional.h"
2427
#include "google/protobuf/descriptor.h"
2528
#include "google/protobuf/descriptor_database.h"
@@ -46,16 +49,24 @@ using ::google::protobuf::python::PyProto_API;
4649
using ::google::protobuf::python::PyProtoAPICapsuleName;
4750

4851
namespace pybind11_protobuf {
49-
namespace {
5052

51-
std::string PythonPackageForDescriptor(const FileDescriptor* file) {
52-
std::vector<std::pair<const absl::string_view, std::string>> replacements;
53-
replacements.emplace_back("/", ".");
54-
replacements.emplace_back(".proto", "_pb2");
55-
std::string name = file->name();
56-
return absl::StrReplaceAll(name, replacements);
53+
std::string StripProtoSuffixFromDescriptorFileName(absl::string_view filename) {
54+
if (absl::EndsWith(filename, ".protodevel")) {
55+
return std::string(absl::StripSuffix(filename, ".protodevel"));
56+
} else {
57+
return std::string(absl::StripSuffix(filename, ".proto"));
58+
}
59+
}
60+
61+
std::string InferPythonModuleNameFromDescriptorFileName(
62+
absl::string_view filename) {
63+
std::string basename = StripProtoSuffixFromDescriptorFileName(filename);
64+
absl::StrReplaceAll({{"-", "_"}, {"/", "."}}, &basename);
65+
return absl::StrCat(basename, "_pb2");
5766
}
5867

68+
namespace {
69+
5970
// Resolves the class name of a descriptor via d->containing_type()
6071
py::object ResolveDescriptor(py::object p, const Descriptor* d) {
6172
return d->containing_type() ? ResolveDescriptor(p, d->containing_type())
@@ -299,7 +310,8 @@ py::module_ GlobalState::ImportCached(const std::string& module_name) {
299310
}
300311

301312
py::object GlobalState::PyMessageInstance(const Descriptor* descriptor) {
302-
auto module_name = PythonPackageForDescriptor(descriptor->file());
313+
auto module_name =
314+
InferPythonModuleNameFromDescriptorFileName(descriptor->file()->name());
303315
if (!module_name.empty()) {
304316
auto cached = import_cache_.find(module_name);
305317
if (cached != import_cache_.end()) {
@@ -567,7 +579,8 @@ void InitializePybindProtoCastUtil() {
567579
void ImportProtoDescriptorModule(const Descriptor* descriptor) {
568580
assert(PyGILState_Check());
569581
if (!descriptor) return;
570-
auto module_name = PythonPackageForDescriptor(descriptor->file());
582+
auto module_name =
583+
InferPythonModuleNameFromDescriptorFileName(descriptor->file()->name());
571584
if (module_name.empty()) return;
572585
try {
573586
GlobalState::instance()->ImportCached(module_name);

pybind11_protobuf/proto_cast_util.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@
2929

3030
namespace pybind11_protobuf {
3131

32+
// Strips ".proto" or ".protodevel" from the end of a filename.
33+
// Similar to
34+
// https://github.com/protocolbuffers/protobuf/blob/b375d010bf57a6d673125330ec47f6e6a7e03f5c/src/google/protobuf/compiler/code_generator.cc#L129-L136
35+
// which is not public, unfortunately. Providing a public function here until
36+
// that situation changes.
37+
std::string StripProtoSuffixFromDescriptorFileName(absl::string_view filename);
38+
39+
// Returns the Python module name expected for a given .proto filename.
40+
// Similar to
41+
// https://github.com/protocolbuffers/protobuf/blob/b375d010bf57a6d673125330ec47f6e6a7e03f5c/src/google/protobuf/compiler/python/helpers.cc#L31-L35
42+
// which is not public, unfortunately. Providing a public function here until
43+
// that situation changes.
44+
std::string InferPythonModuleNameFromDescriptorFileName(
45+
absl::string_view filename);
46+
3247
// Simple helper. Caller has to ensure that the py_bytes argument outlives the
3348
// returned string_view.
3449
absl::string_view PyBytesAsStringView(pybind11::bytes py_bytes);

pybind11_protobuf/tests/BUILD

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,21 @@ py_proto_library(
9898
deps = [":extension_in_other_file_proto"],
9999
)
100100

101+
proto_library(
102+
name = "we-love-dashes_proto",
103+
srcs = ["we-love-dashes.proto"],
104+
)
105+
106+
cc_proto_library(
107+
name = "we_love_dashes_cc_proto",
108+
deps = [":we-love-dashes_proto"],
109+
)
110+
111+
py_proto_library(
112+
name = "we-love-dashes_py_pb2",
113+
deps = [":we-love-dashes_proto"],
114+
)
115+
101116
# Tests for enum_type_caster
102117

103118
pybind_extension(
@@ -342,3 +357,47 @@ py_test(
342357
requirement("absl_py"),
343358
],
344359
)
360+
361+
pybind_extension(
362+
name = "we_love_dashes_module",
363+
srcs = ["we_love_dashes_module.cc"],
364+
deps = [
365+
":we_love_dashes_cc_proto",
366+
"//pybind11_protobuf:native_proto_caster",
367+
],
368+
)
369+
370+
py_test(
371+
name = "we_love_dashes_cc_only_test",
372+
srcs = ["we_love_dashes_cc_only_test.py"],
373+
deps = [
374+
":we_love_dashes_module",
375+
"@com_google_absl_py//absl/testing:absltest",
376+
"@com_google_protobuf//:protobuf_python",
377+
requirement("absl_py"),
378+
],
379+
)
380+
381+
py_test(
382+
name = "we_love_dashes_cc_and_py_in_deps_test",
383+
srcs = ["we_love_dashes_cc_and_py_in_deps_test.py"],
384+
deps = [
385+
":we_love_dashes_module",
386+
":we-love-dashes_py_pb2", # fixdeps: keep
387+
"@com_google_absl_py//absl/testing:absltest",
388+
"@com_google_protobuf//:protobuf_python",
389+
requirement("absl_py"),
390+
],
391+
)
392+
393+
py_test(
394+
name = "we_love_dashes_py_only_test",
395+
srcs = ["we_love_dashes_py_only_test.py"],
396+
deps = [
397+
":very_large_proto_module",
398+
":we-love-dashes_py_pb2",
399+
"@com_google_absl_py//absl/testing:absltest",
400+
"@com_google_protobuf//:protobuf_python",
401+
requirement("absl_py"),
402+
],
403+
)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
syntax = "proto2";
2+
3+
package pybind11.test;
4+
5+
message TokenEffort {
6+
optional int32 score = 1;
7+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2024 The Pybind Development Team. All rights reserved.
2+
#
3+
# All rights reserved. Use of this source code is governed by a
4+
# BSD-style license that can be found in the LICENSE file.
5+
6+
from absl.testing import absltest
7+
from pybind11_protobuf.tests import we_love_dashes_module
8+
9+
# NOTE: ":we-love-dashes_py_pb2" is in deps but intentionally not imported here.
10+
11+
12+
class MessageTest(absltest.TestCase):
13+
14+
def test_return_then_pass(self):
15+
msg = we_love_dashes_module.return_token_effort(234)
16+
score = we_love_dashes_module.pass_token_effort(msg)
17+
self.assertEqual(score, 234)
18+
19+
20+
if __name__ == '__main__':
21+
absltest.main()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) 2024 The Pybind Development Team. All rights reserved.
2+
#
3+
# All rights reserved. Use of this source code is governed by a
4+
# BSD-style license that can be found in the LICENSE file.
5+
6+
from absl.testing import absltest
7+
from google.protobuf.internal import api_implementation
8+
from pybind11_protobuf.tests import we_love_dashes_module
9+
10+
11+
class MessageTest(absltest.TestCase):
12+
13+
def test_return_then_pass(self):
14+
if api_implementation.Type() == 'cpp':
15+
msg = we_love_dashes_module.return_token_effort(234)
16+
score = we_love_dashes_module.pass_token_effort(msg)
17+
self.assertEqual(score, 234)
18+
else:
19+
with self.assertRaisesRegex(
20+
TypeError,
21+
r'^Cannot construct a protocol buffer message type'
22+
r' pybind11\.test\.TokenEffort in python\.'
23+
r' .*pybind11_protobuf\.tests\.we_love_dashes_pb2\?$',
24+
):
25+
we_love_dashes_module.return_token_effort(0)
26+
27+
28+
if __name__ == '__main__':
29+
absltest.main()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2024 The Pybind Development Team. All rights reserved.
2+
//
3+
// All rights reserved. Use of this source code is governed by a
4+
// BSD-style license that can be found in the LICENSE file.
5+
6+
#include <pybind11/pybind11.h>
7+
8+
#include "pybind11_protobuf/native_proto_caster.h"
9+
#include "pybind11_protobuf/tests/we-love-dashes.pb.h"
10+
11+
namespace {
12+
13+
PYBIND11_MODULE(we_love_dashes_module, m) {
14+
pybind11_protobuf::ImportNativeProtoCasters();
15+
16+
m.def("return_token_effort", [](int score) {
17+
pybind11::test::TokenEffort msg;
18+
msg.set_score(score);
19+
return msg;
20+
});
21+
22+
m.def("pass_token_effort",
23+
[](const pybind11::test::TokenEffort& msg) { return msg.score(); });
24+
}
25+
26+
} // namespace
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2024 The Pybind Development Team. All rights reserved.
2+
#
3+
# All rights reserved. Use of this source code is governed by a
4+
# BSD-style license that can be found in the LICENSE file.
5+
6+
from absl.testing import absltest
7+
8+
from pybind11_protobuf.tests import very_large_proto_module
9+
from pybind11_protobuf.tests import we_love_dashes_pb2
10+
11+
12+
class MessageTest(absltest.TestCase):
13+
14+
def test_pass_proto2_message(self):
15+
msg = we_love_dashes_pb2.TokenEffort(score=345)
16+
space_used_estimate = very_large_proto_module.get_space_used_estimate(msg)
17+
self.assertGreater(space_used_estimate, 0)
18+
19+
20+
if __name__ == '__main__':
21+
absltest.main()

0 commit comments

Comments
 (0)