Skip to content

Commit e323d45

Browse files
authored
fix(internal/librarian/rust): map recursive fields to Box (#4074)
Recursive fields that are not repeated or oneof should be mapped to Box.
1 parent 2b3750a commit e323d45

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

internal/sidekick/rust/annotate.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,8 @@ type fieldAnnotations struct {
505505
IsBoxed bool
506506
// If true, it requires a serde_with::serde_as() transformation.
507507
SerdeAs string
508+
// If true, the field is boxed in the prost generated type.
509+
MapToBoxed bool
508510
// If true, use `wkt::internal::is_default()` to skip the field
509511
SkipIfIsDefault bool
510512
// If true, this is a `wkt::Value` field, and requires super-extra custom
@@ -1479,6 +1481,7 @@ func (c *codec) annotateField(field *api.Field, message *api.Message, model *api
14791481
if field.Recursive || (field.Typez == api.MESSAGE_TYPE && field.IsOneOf) {
14801482
ann.IsBoxed = true
14811483
}
1484+
ann.MapToBoxed = mapToBoxed(field, message, model)
14821485
field.Codec = ann
14831486
if field.Typez == api.MESSAGE_TYPE {
14841487
if msg, ok := model.State.MessageByID[field.TypezID]; ok && msg.IsMap {
@@ -1651,3 +1654,39 @@ func isIdempotent(p *api.PathInfo) string {
16511654
}
16521655
return "true"
16531656
}
1657+
1658+
// mapToBoxed returns true if the prost generated type for this field is boxed.
1659+
// Prost boxes fields that would cause an infinitely sized struct, which happens
1660+
// on recursive cycles that are not broken by a repeated or map field.
1661+
func mapToBoxed(field *api.Field, message *api.Message, model *api.API) bool {
1662+
if field.Typez != api.MESSAGE_TYPE || field.Repeated || field.Map {
1663+
return false
1664+
}
1665+
1666+
var check func(typezID string, targetID string, visited map[string]bool) bool
1667+
check = func(typezID string, targetID string, visited map[string]bool) bool {
1668+
if typezID == targetID {
1669+
return true
1670+
}
1671+
if visited[typezID] {
1672+
return false
1673+
}
1674+
visited[typezID] = true
1675+
msg, ok := model.State.MessageByID[typezID]
1676+
if !ok {
1677+
return false
1678+
}
1679+
for _, f := range msg.Fields {
1680+
if f.Typez != api.MESSAGE_TYPE || f.Repeated || f.Map {
1681+
continue
1682+
}
1683+
if check(f.TypezID, targetID, visited) {
1684+
return true
1685+
}
1686+
}
1687+
return false
1688+
}
1689+
1690+
visited := make(map[string]bool)
1691+
return check(field.TypezID, message.ID, visited)
1692+
}

internal/sidekick/rust/annotate_field_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ func TestFieldAnnotations(t *testing.T) {
183183
PrimitiveFieldType: "crate::model::TestMessage",
184184
AddQueryParameter: `let builder = req.boxed_field.as_ref().map(|p| serde_json::to_value(p).map_err(Error::ser) ).transpose()?.into_iter().fold(builder, |builder, v| { use gaxi::query_parameter::QueryParameter; v.add(builder, "boxedField") });`,
185185
IsBoxed: true,
186+
MapToBoxed: true,
186187
SkipIfIsDefault: true,
187188
FieldTypeIsParentType: true,
188189
}
@@ -295,6 +296,7 @@ func TestRecursiveFieldAnnotations(t *testing.T) {
295296
ValueField: value_field,
296297
SerdeAs: "std::collections::HashMap<wkt::internal::I32, serde_with::Same>",
297298
IsBoxed: true,
299+
MapToBoxed: true,
298300
SkipIfIsDefault: true,
299301
FieldTypeIsParentType: true,
300302
}
@@ -318,6 +320,7 @@ func TestRecursiveFieldAnnotations(t *testing.T) {
318320
PrimitiveFieldType: "crate::model::TestMessage",
319321
AddQueryParameter: `let builder = req.oneof_field().map(|p| serde_json::to_value(p).map_err(Error::ser) ).transpose()?.into_iter().fold(builder, |builder, p| { use gaxi::query_parameter::QueryParameter; p.add(builder, "oneofField") });`,
320322
IsBoxed: true,
323+
MapToBoxed: true,
321324
SkipIfIsDefault: true,
322325
OtherFieldsInGroup: []*api.Field{},
323326
FieldTypeIsParentType: true,
@@ -342,6 +345,7 @@ func TestRecursiveFieldAnnotations(t *testing.T) {
342345
PrimitiveFieldType: "crate::model::TestMessage",
343346
AddQueryParameter: `let builder = req.repeated_field.as_ref().map(|p| serde_json::to_value(p).map_err(Error::ser) ).transpose()?.into_iter().fold(builder, |builder, v| { use gaxi::query_parameter::QueryParameter; v.add(builder, "repeatedField") });`,
344347
IsBoxed: true,
348+
MapToBoxed: false,
345349
SkipIfIsDefault: true,
346350
FieldTypeIsParentType: true,
347351
}
@@ -365,6 +369,7 @@ func TestRecursiveFieldAnnotations(t *testing.T) {
365369
PrimitiveFieldType: "crate::model::TestMessage",
366370
AddQueryParameter: `let builder = { use gaxi::query_parameter::QueryParameter; serde_json::to_value(&req.message_field).map_err(Error::ser)?.add(builder, "messageField") };`,
367371
IsBoxed: true,
372+
MapToBoxed: true,
368373
SkipIfIsDefault: true,
369374
FieldTypeIsParentType: true,
370375
}
@@ -508,6 +513,7 @@ func TestSameTypeNameFieldAnnotations(t *testing.T) {
508513
PrimitiveFieldType: "rusty_test_inner_v1::model::TestMessage",
509514
AddQueryParameter: `let builder = req.oneof_field().map(|p| serde_json::to_value(p).map_err(Error::ser) ).transpose()?.into_iter().fold(builder, |builder, p| { use gaxi::query_parameter::QueryParameter; p.add(builder, "oneofField") });`,
510515
IsBoxed: true,
516+
MapToBoxed: false,
511517
SkipIfIsDefault: true,
512518
OtherFieldsInGroup: []*api.Field{},
513519
AliasInExamples: "OneofField",

internal/sidekick/rust/templates/convert-prost/message.mustache

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ impl gaxi::prost::ToProto<{{Codec.RelativeName}}> for {{Codec.QualifiedName}} {
3535
{{Codec.FieldName}}: self.{{Codec.FieldName}}.to_proto()?,
3636
{{/Optional}}
3737
{{#Optional}}
38+
{{^Codec.MapToBoxed}}
3839
{{Codec.FieldName}}: self.{{Codec.FieldName}}.map(|v| v.to_proto()).transpose()?,
40+
{{/Codec.MapToBoxed}}
41+
{{#Codec.MapToBoxed}}
42+
{{Codec.FieldName}}: self.{{Codec.FieldName}}.map(|v| v.to_proto().map(std::boxed::Box::new)).transpose()?,
43+
{{/Codec.MapToBoxed}}
3944
{{/Optional}}
4045
{{/Singular}}
4146
{{#Repeated}}

0 commit comments

Comments
 (0)