Skip to content

Commit 4de4429

Browse files
committed
mpgen: support primitive std::optional struct fields
Currently optional primitive fields like `std::optional<int>` are not well supported as struct members. Non-primitive optional fields like `std::optional<std::string>` and optional struct fields are well-supported because Cap'n Proto allows non-primitive fields to be unset, but primitive fields are always considered set so there is natural way to represent null values. Libmultiprocess does already support primitive optional method parameters and result values, by allowing the .capnp files to declare extra boolean parameters prefixed with "has" and treating the extra boolean parameters as indicators of whether options are set or unset. This commit just this functionality to work for struct members as well. For example a C++ `std::optional<int> param` parameter can be represented by 'param :Int32, hasParam :Bool` parameters in a .capnp file and libmultiprocess will use both Cap'n Proto fields together to represent the C++ value. Now C++ struct fields can be represented the same way (see unit changes test for an example).
1 parent d4ee75d commit 4de4429

File tree

5 files changed

+36
-21
lines changed

5 files changed

+36
-21
lines changed

src/mp/gen.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ struct FieldList
145145
std::map<kj::StringPtr, int> field_idx; // name -> args index
146146
bool has_result = false;
147147

148-
void addField(const ::capnp::StructSchema::Field& schema_field, bool param)
148+
void addField(const ::capnp::StructSchema::Field& schema_field, bool param, bool result)
149149
{
150150
if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
151151
return;
@@ -160,7 +160,8 @@ struct FieldList
160160
if (param) {
161161
field.param = schema_field;
162162
field.param_is_set = true;
163-
} else {
163+
}
164+
if (result) {
164165
field.result = schema_field;
165166
field.result_is_set = true;
166167
}
@@ -433,6 +434,13 @@ static void Generate(kj::StringPtr src_prefix,
433434

434435
if (node.getProto().isStruct()) {
435436
const auto& struc = node.asStruct();
437+
438+
FieldList fields;
439+
for (const auto schema_field : struc.getFields()) {
440+
fields.addField(schema_field, true, true);
441+
}
442+
fields.mergeFields();
443+
436444
std::ostringstream generic_name;
437445
generic_name << node_name;
438446
dec << "template<";
@@ -453,22 +461,19 @@ static void Generate(kj::StringPtr src_prefix,
453461
dec << "struct ProxyStruct<" << message_namespace << "::" << generic_name.str() << ">\n";
454462
dec << "{\n";
455463
dec << " using Struct = " << message_namespace << "::" << generic_name.str() << ";\n";
456-
for (const auto field : struc.getFields()) {
457-
auto field_name = field.getProto().getName();
464+
for (const auto& field : fields.fields) {
465+
if (field.skip) continue;
466+
auto field_name = field.param.getProto().getName();
458467
add_accessor(field_name);
459-
dec << " using " << Cap(field_name) << "Accessor = Accessor<" << base_name
460-
<< "_fields::" << Cap(field_name) << ", FIELD_IN | FIELD_OUT";
461-
if (BoxedType(field.getType())) dec << " | FIELD_BOXED";
462-
dec << ">;\n";
468+
dec << " using " << Cap(field_name) << "Accessor = "
469+
<< AccessorType(base_name, field) << ";\n";
463470
}
464471
dec << " using Accessors = std::tuple<";
465472
size_t i = 0;
466-
for (const auto field : struc.getFields()) {
467-
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
468-
continue;
469-
}
473+
for (const auto& field : fields.fields) {
474+
if (field.skip) continue;
470475
if (i) dec << ", ";
471-
dec << Cap(field.getProto().getName()) << "Accessor";
476+
dec << Cap(field.param.getProto().getName()) << "Accessor";
472477
++i;
473478
}
474479
dec << ">;\n";
@@ -482,13 +487,11 @@ static void Generate(kj::StringPtr src_prefix,
482487
inl << "public:\n";
483488
inl << " using Struct = " << message_namespace << "::" << node_name << ";\n";
484489
size_t i = 0;
485-
for (const auto field : struc.getFields()) {
486-
if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
487-
continue;
488-
}
489-
auto field_name = field.getProto().getName();
490+
for (const auto& field : fields.fields) {
491+
if (field.skip) continue;
492+
auto field_name = field.param.getProto().getName();
490493
auto member_name = field_name;
491-
GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name);
494+
GetAnnotationText(field.param.getProto(), NAME_ANNOTATION_ID, &member_name);
492495
inl << " static decltype(auto) get(std::integral_constant<size_t, " << i << ">) { return "
493496
<< "&" << proxied_class_type << "::" << member_name << "; }\n";
494497
++i;
@@ -533,10 +536,10 @@ static void Generate(kj::StringPtr src_prefix,
533536

534537
FieldList fields;
535538
for (const auto schema_field : method.getParamType().getFields()) {
536-
fields.addField(schema_field, true);
539+
fields.addField(schema_field, true, false);
537540
}
538541
for (const auto schema_field : method.getResultType().getFields()) {
539-
fields.addField(schema_field, false);
542+
fields.addField(schema_field, false, true);
540543
}
541544
fields.mergeFields();
542545

test/mp/test/foo-types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <mp/type-map.h>
2020
#include <mp/type-message.h>
2121
#include <mp/type-number.h>
22+
#include <mp/type-optional.h>
2223
#include <mp/type-set.h>
2324
#include <mp/type-string.h>
2425
#include <mp/type-struct.h>

test/mp/test/foo.capnp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ struct FooStruct $Proxy.wrap("mp::test::FooStruct") {
5353
name @0 :Text;
5454
setint @1 :List(Int32);
5555
vbool @2 :List(Bool);
56+
optionalInt @3 :Int32 $Proxy.name("optional_int");
57+
hasOptionalInt @4 :Bool;
5658
}
5759

5860
struct FooCustom $Proxy.wrap("mp::test::FooCustom") {

test/mp/test/foo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <functional>
1010
#include <map>
1111
#include <memory>
12+
#include <optional>
1213
#include <string>
1314
#include <set>
1415
#include <vector>
@@ -21,6 +22,7 @@ struct FooStruct
2122
std::string name;
2223
std::set<int> setint;
2324
std::vector<bool> vbool;
25+
std::optional<int> optional_int;
2426
};
2527

2628
enum class FooEnum : uint8_t { ONE = 1, TWO = 2, };

test/mp/test/test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ KJ_TEST("Call FooInterface methods")
141141
in.vbool.push_back(false);
142142
in.vbool.push_back(true);
143143
in.vbool.push_back(false);
144+
in.optional_int = 3;
144145
FooStruct out = foo->pass(in);
145146
KJ_EXPECT(in.name == out.name);
146147
KJ_EXPECT(in.setint.size() == out.setint.size());
@@ -151,6 +152,12 @@ KJ_TEST("Call FooInterface methods")
151152
for (size_t i = 0; i < in.vbool.size(); ++i) {
152153
KJ_EXPECT(in.vbool[i] == out.vbool[i]);
153154
}
155+
KJ_EXPECT(in.optional_int == out.optional_int);
156+
157+
// Additional checks for std::optional member
158+
KJ_EXPECT(foo->pass(in).optional_int == 3);
159+
in.optional_int.reset();
160+
KJ_EXPECT(!foo->pass(in).optional_int);
154161

155162
FooStruct err;
156163
try {

0 commit comments

Comments
 (0)