Skip to content

Commit e237fb2

Browse files
committed
feat(db): support key in nested message
Resolve key fields declared in nested messages (for example metadata.id) across key lookup, indexability checks, order planning, and field-reader path resolution. Add recursive protoc-gen validation to reject invalid/ambiguous nested key declarations and extend tests for nested-key behavior. Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
1 parent 068c5f9 commit e237fb2

File tree

9 files changed

+424
-58
lines changed

9 files changed

+424
-58
lines changed

cmd/protoc-gen-protodb/main.go

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ type mod struct {
4242
oneOfs map[string]struct{}
4343
}
4444

45+
type keyPath []pgs.Field
46+
4547
func (p *mod) Name() string {
4648
return "protodb"
4749
}
@@ -85,28 +87,87 @@ func (p *mod) Execute(targets map[string]pgs.File, _ map[string]pgs.Package) []p
8587
if !ok || !enabled {
8688
continue
8789
}
88-
var keyField string
89-
for _, f := range m.Fields() {
90-
var k bool
91-
f.Extension(protodb.E_Key, &k)
92-
if !k {
93-
continue
94-
}
95-
if keyField != "" {
96-
p.Failf("%s: %s: key already defined for message: %s", m.Name(), f.Name(), keyField)
97-
}
98-
keyField = f.Name().String()
99-
if f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() || f.Descriptor().GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
100-
p.Failf("%s: %s: only non repeated and not oneof scalar types are supported as key, got %T", m.Name(), f.Name(), f.Type())
101-
}
102-
}
90+
p.validateMessageKey(m)
10391
msgs = append(msgs, m)
10492
}
10593
p.generate(f, msgs)
10694
}
10795
return p.Artifacts()
10896
}
10997

98+
func (p *mod) validateMessageKey(m pgs.Message) {
99+
keys := p.collectMessageKeys(m)
100+
if len(keys) == 0 {
101+
return
102+
}
103+
for _, path := range keys {
104+
leaf := path.leaf()
105+
if leaf.Type().IsMap() || leaf.Type().IsRepeated() || leaf.InOneOf() || leaf.Descriptor().GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
106+
p.Failf("%s: %s: only non repeated and not oneof scalar types are supported as key, got %T", m.Name(), path.String(), leaf.Type())
107+
}
108+
for i := 0; i < len(path)-1; i++ {
109+
f := path[i]
110+
if !f.Type().IsEmbed() || f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() {
111+
p.Failf("%s: %s: nested key path must traverse singular message fields", m.Name(), path.String())
112+
}
113+
}
114+
}
115+
if len(keys) == 1 {
116+
return
117+
}
118+
paths := make([]string, 0, len(keys))
119+
for _, path := range keys {
120+
paths = append(paths, path.String())
121+
}
122+
p.Failf("%s: multiple keys found for message: %s", m.Name(), strings.Join(paths, ", "))
123+
}
124+
125+
func (p *mod) collectMessageKeys(m pgs.Message) []keyPath {
126+
stack := map[string]bool{m.FullyQualifiedName(): true}
127+
return p.collectMessageKeysIn(m, nil, stack)
128+
}
129+
130+
func (p *mod) collectMessageKeysIn(m pgs.Message, path keyPath, stack map[string]bool) []keyPath {
131+
var out []keyPath
132+
for _, f := range m.Fields() {
133+
next := make(keyPath, 0, len(path)+1)
134+
next = append(next, path...)
135+
next = append(next, f)
136+
var key bool
137+
f.Extension(protodb.E_Key, &key)
138+
if key {
139+
out = append(out, next)
140+
}
141+
if !f.Type().IsEmbed() || f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() {
142+
continue
143+
}
144+
embed := f.Type().Embed()
145+
if embed == nil {
146+
continue
147+
}
148+
name := embed.FullyQualifiedName()
149+
if stack[name] {
150+
continue
151+
}
152+
stack[name] = true
153+
out = append(out, p.collectMessageKeysIn(embed, next, stack)...)
154+
delete(stack, name)
155+
}
156+
return out
157+
}
158+
159+
func (p keyPath) leaf() pgs.Field {
160+
return p[len(p)-1]
161+
}
162+
163+
func (p keyPath) String() string {
164+
parts := make([]string, 0, len(p))
165+
for _, f := range p {
166+
parts = append(parts, f.Name().String())
167+
}
168+
return strings.Join(parts, ".")
169+
}
170+
110171
func (p *mod) generate(f pgs.File, msgs []pgs.Message) {
111172
if len(msgs) == 0 {
112173
return

internal/db/tx.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ func buildOrderPlan(msg protoreflect.Message, orderBy *v1alpha1.OrderBy) (orderF
11741174
if !isOrderableFieldPath(fds) {
11751175
return orderField{}, fmt.Errorf("order_by field %q is not sortable", fieldPath)
11761176
}
1177-
isKey := hasKey && len(fds) == 1 && string(fds[0].Name()) == keyField
1177+
isKey := hasKey && nameFieldPath(fds) == keyField
11781178
if !isKey {
11791179
fd := fds[len(fds)-1]
11801180
if !proto.HasExtension(fd.Options(), protopts.E_Index) {
@@ -1319,3 +1319,11 @@ func hash(f protodb.Filter) (hash string, err error) {
13191319
h := sha512.Sum512(b)
13201320
return base64.StdEncoding.EncodeToString(h[:]), nil
13211321
}
1322+
1323+
func nameFieldPath(fds []protoreflect.FieldDescriptor) string {
1324+
parts := make([]string, 0, len(fds))
1325+
for _, fd := range fds {
1326+
parts = append(parts, string(fd.Name()))
1327+
}
1328+
return strings.Join(parts, ".")
1329+
}

internal/db/tx_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@ import (
66
"testing"
77

88
"github.com/dgraph-io/badger/v3"
9+
"github.com/jhump/protoreflect/v2/protobuilder"
910
"github.com/stretchr/testify/require"
1011
"go.linka.cloud/protofilters/filters"
12+
"google.golang.org/protobuf/proto"
13+
"google.golang.org/protobuf/reflect/protodesc"
1114
"google.golang.org/protobuf/reflect/protoreflect"
15+
"google.golang.org/protobuf/types/descriptorpb"
1216
"google.golang.org/protobuf/types/dynamicpb"
1317

1418
"go.linka.cloud/protodb/internal/protodb"
1519
"go.linka.cloud/protodb/internal/token"
20+
protopts "go.linka.cloud/protodb/protodb"
1621
v1alpha1 "go.linka.cloud/protodb/protodb/v1alpha1"
1722
)
1823

@@ -55,6 +60,18 @@ func TestBuildOrderPlan(t *testing.T) {
5560
_, err := buildOrderPlan(m.ProtoReflect(), &v1alpha1.OrderBy{Field: "status", Direction: v1alpha1.OrderDirection(99)})
5661
require.ErrorContains(t, err, "invalid direction")
5762
})
63+
64+
require.NoError(t, d.RegisterProto(ctx, buildNestedKeyOrderFileDescriptor(t)))
65+
nmd, err := lookupMessage(d, "tests.order.OrderDoc")
66+
require.NoError(t, err)
67+
nm := dynamicpb.NewMessage(nmd)
68+
69+
t.Run("nested key field", func(t *testing.T) {
70+
pl, err := buildOrderPlan(nm.ProtoReflect(), &v1alpha1.OrderBy{Field: "metadata.id"})
71+
require.NoError(t, err)
72+
require.Equal(t, "metadata.id", pl.fieldPath)
73+
require.Equal(t, v1alpha1.OrderDirectionAsc, pl.direction)
74+
})
5875
}
5976

6077
func TestGetContinuationValidation(t *testing.T) {
@@ -297,6 +314,30 @@ func lookupNumberField(md protoreflect.MessageDescriptor, n string) (protoreflec
297314
return nil, fmt.Errorf("field number %s not found", n)
298315
}
299316

317+
func buildNestedKeyOrderFileDescriptor(t *testing.T) *descriptorpb.FileDescriptorProto {
318+
t.Helper()
319+
keyOpts := &descriptorpb.FieldOptions{}
320+
proto.SetExtension(keyOpts, protopts.E_Key, true)
321+
322+
meta := protobuilder.NewMessage("Metadata").
323+
AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts))
324+
325+
msg := protobuilder.NewMessage("OrderDoc").
326+
AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)).
327+
AddField(protobuilder.NewField("status", protobuilder.FieldTypeString()).SetNumber(2)).
328+
AddNestedMessage(meta)
329+
330+
file := protobuilder.NewFile("tests/order_nested_key.proto").
331+
SetPackageName(protoreflect.FullName("tests.order")).
332+
SetSyntax(protoreflect.Proto3).
333+
AddMessage(msg).
334+
AddImportedDependency(protopts.File_protodb_protodb_proto)
335+
336+
fd, err := file.Build()
337+
require.NoError(t, err)
338+
return protodesc.ToFileDescriptorProto(fd)
339+
}
340+
300341
func TestGetWithMalformedTokenReturnsInvalid(t *testing.T) {
301342
ctx := context.Background()
302343
dbif, err := Open(ctx, WithPath(t.TempDir()), WithApplyDefaults(true))

internal/index/indexer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,14 +984,14 @@ func isKeyField(md protoreflect.MessageDescriptor, fds []protoreflect.FieldDescr
984984
if md == nil {
985985
return false
986986
}
987-
if len(fds) != 1 {
987+
if len(fds) == 0 {
988988
return false
989989
}
990990
field, ok := protodb.KeyFieldName(md)
991991
if !ok {
992992
return false
993993
}
994-
return field == string(fds[0].Name())
994+
return field == fieldPathFromNames(fds)
995995
}
996996

997997
func isIndexableFieldPathDescriptors(fds []protoreflect.FieldDescriptor) bool {

internal/index/indexer_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ func TestIndexableFilter(t *testing.T) {
5252

5353
_, err = idx.IndexableFilter(m, filters.Where("meta.nope").StringEquals("x"))
5454
require.Error(t, err)
55+
56+
nestedKey := buildIndexerNestedKeyDescriptor(t)
57+
nestedMsg := dynamicpb.NewMessage(nestedKey)
58+
ok, err = idx.IndexableFilter(nestedMsg, filters.Where("metadata.id").StringEquals("x"))
59+
require.NoError(t, err)
60+
require.True(t, ok)
5561
}
5662

5763
func TestEnforceUnique(t *testing.T) {
@@ -349,6 +355,32 @@ func buildUniqueDocDescriptor(t *testing.T) protoreflect.MessageDescriptor {
349355
return md
350356
}
351357

358+
func buildIndexerNestedKeyDescriptor(t *testing.T) protoreflect.MessageDescriptor {
359+
t.Helper()
360+
keyOpts := &descriptorpb.FieldOptions{}
361+
proto.SetExtension(keyOpts, protopts.E_Key, true)
362+
363+
meta := protobuilder.NewMessage("Metadata").
364+
AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts))
365+
366+
msg := protobuilder.NewMessage("NestedKeyDoc").
367+
AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)).
368+
AddField(protobuilder.NewField("status", protobuilder.FieldTypeString()).SetNumber(2).SetOptions(&descriptorpb.FieldOptions{})).
369+
AddNestedMessage(meta)
370+
371+
file := protobuilder.NewFile("tests/indexer_nested_key.proto").
372+
SetPackageName(protoreflect.FullName("tests.index")).
373+
SetSyntax(protoreflect.Proto3).
374+
AddMessage(msg).
375+
AddImportedDependency(protopts.File_protodb_protodb_proto)
376+
377+
fd, err := file.Build()
378+
require.NoError(t, err)
379+
md := fd.Messages().ByName("NestedKeyDoc")
380+
require.NotNil(t, md)
381+
return md
382+
}
383+
352384
func buildUniqueTagsDescriptor(t *testing.T) protoreflect.MessageDescriptor {
353385
t.Helper()
354386
keyOpts := &descriptorpb.FieldOptions{}

internal/index/store.go

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ func (f *fieldReader) Get(_ context.Context, n protoreflect.Name) iter.Seq2[pfin
321321

322322
type keyField struct {
323323
name protoreflect.Name
324-
fd protoreflect.FieldDescriptor
324+
fds []protoreflect.FieldDescriptor
325325
dataPrefix []byte
326326
txn badgerd.Tx
327327
entries []keyEntry
@@ -371,7 +371,7 @@ func (k *keyField) iterate(yield func(pfindex.Field, error) bool) {
371371
}
372372
for _, e := range k.entries {
373373
entry := &field{
374-
descriptors: []protoreflect.FieldDescriptor{k.fd},
374+
descriptors: k.fds,
375375
value: protoreflect.ValueOfString(e.value),
376376
bitmap: e.bitmap,
377377
}
@@ -558,13 +558,7 @@ func fieldPathFromNames(fds []protoreflect.FieldDescriptor) string {
558558
}
559559

560560
func lookupByNumberPath(md protoreflect.MessageDescriptor, fieldPath string) ([]protoreflect.FieldDescriptor, error) {
561-
parts := strings.Split(fieldPath, ".")
562-
if len(parts) == 0 {
563-
return nil, fmt.Errorf("empty field path")
564-
}
565-
var fds []protoreflect.FieldDescriptor
566-
cur := md
567-
for _, part := range parts {
561+
return lookupByPath(md, fieldPath, func(cur protoreflect.MessageDescriptor, part string) (protoreflect.FieldDescriptor, error) {
568562
num, err := strconv.Atoi(part)
569563
if err != nil {
570564
return nil, fmt.Errorf("invalid field number %q", part)
@@ -573,15 +567,40 @@ func lookupByNumberPath(md protoreflect.MessageDescriptor, fieldPath string) ([]
573567
if fd == nil {
574568
return nil, fmt.Errorf("%s does not contain field number %d", cur.FullName(), num)
575569
}
570+
return fd, nil
571+
})
572+
}
573+
574+
func lookupByNamePath(md protoreflect.MessageDescriptor, fieldPath string) ([]protoreflect.FieldDescriptor, error) {
575+
return lookupByPath(md, fieldPath, func(cur protoreflect.MessageDescriptor, part string) (protoreflect.FieldDescriptor, error) {
576+
fd := cur.Fields().ByName(protoreflect.Name(part))
577+
if fd == nil {
578+
return nil, fmt.Errorf("%s does not contain field '%s'", cur.FullName(), part)
579+
}
580+
return fd, nil
581+
})
582+
}
583+
584+
func lookupByPath(md protoreflect.MessageDescriptor, fieldPath string, lookup func(protoreflect.MessageDescriptor, string) (protoreflect.FieldDescriptor, error)) ([]protoreflect.FieldDescriptor, error) {
585+
if fieldPath == "" {
586+
return nil, fmt.Errorf("empty field path")
587+
}
588+
parts := strings.Split(fieldPath, ".")
589+
var fds []protoreflect.FieldDescriptor
590+
cur := md
591+
for i, part := range parts {
592+
fd, err := lookup(cur, part)
593+
if err != nil {
594+
return nil, err
595+
}
576596
fds = append(fds, fd)
577-
if fd.Kind() == protoreflect.MessageKind {
578-
cur = fd.Message()
579-
} else {
580-
cur = nil
597+
if i == len(parts)-1 {
598+
continue
581599
}
582-
if cur == nil && part != parts[len(parts)-1] {
600+
if fd.Kind() != protoreflect.MessageKind {
583601
return nil, fmt.Errorf("%s does not contain '%s'", md.FullName(), fieldPath)
584602
}
603+
cur = fd.Message()
585604
}
586605
return fds, nil
587606
}
@@ -598,11 +617,11 @@ func buildFieldReader(txn badgerd.Tx, resolver protodesc.Resolver, name protoref
598617
var key *keyField
599618
keyName, ok := protodb.KeyFieldName(md)
600619
if ok {
601-
fd := md.Fields().ByName(protoreflect.Name(keyName))
602-
if fd != nil {
620+
fds, err := lookupByNamePath(md, keyName)
621+
if err == nil {
603622
key = &keyField{
604-
name: fd.Name(),
605-
fd: fd,
623+
name: protoreflect.Name(fieldPathFromNames(fds)),
624+
fds: fds,
606625
dataPrefix: []byte(protodb.Data + "/" + string(name) + "/"),
607626
txn: txn,
608627
}

0 commit comments

Comments
 (0)