diff --git a/cmd/protoc-gen-protodb/main.go b/cmd/protoc-gen-protodb/main.go index 536919f..e87e59f 100644 --- a/cmd/protoc-gen-protodb/main.go +++ b/cmd/protoc-gen-protodb/main.go @@ -42,6 +42,8 @@ type mod struct { oneOfs map[string]struct{} } +type keyPath []pgs.Field + func (p *mod) Name() string { return "protodb" } @@ -85,21 +87,7 @@ func (p *mod) Execute(targets map[string]pgs.File, _ map[string]pgs.Package) []p if !ok || !enabled { continue } - var keyField string - for _, f := range m.Fields() { - var k bool - f.Extension(protodb.E_Key, &k) - if !k { - continue - } - if keyField != "" { - p.Failf("%s: %s: key already defined for message: %s", m.Name(), f.Name(), keyField) - } - keyField = f.Name().String() - if f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() || f.Descriptor().GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE { - p.Failf("%s: %s: only non repeated and not oneof scalar types are supported as key, got %T", m.Name(), f.Name(), f.Type()) - } - } + p.validateMessageKey(m) msgs = append(msgs, m) } p.generate(f, msgs) @@ -107,6 +95,79 @@ func (p *mod) Execute(targets map[string]pgs.File, _ map[string]pgs.Package) []p return p.Artifacts() } +func (p *mod) validateMessageKey(m pgs.Message) { + keys := p.collectMessageKeys(m) + if len(keys) == 0 { + return + } + for _, path := range keys { + leaf := path.leaf() + if leaf.Type().IsMap() || leaf.Type().IsRepeated() || leaf.InOneOf() || leaf.Descriptor().GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE { + p.Failf("%s: %s: only non repeated and not oneof scalar types are supported as key, got %T", m.Name(), path.String(), leaf.Type()) + } + for i := 0; i < len(path)-1; i++ { + f := path[i] + if !f.Type().IsEmbed() || f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() { + p.Failf("%s: %s: nested key path must traverse singular message fields", m.Name(), path.String()) + } + } + } + if len(keys) == 1 { + return + } + paths := make([]string, 0, len(keys)) + for _, path := range keys { + paths = append(paths, path.String()) + } + p.Failf("%s: multiple keys found for message: %s", m.Name(), strings.Join(paths, ", ")) +} + +func (p *mod) collectMessageKeys(m pgs.Message) []keyPath { + stack := map[string]bool{m.FullyQualifiedName(): true} + return p.collectMessageKeysIn(m, nil, stack) +} + +func (p *mod) collectMessageKeysIn(m pgs.Message, path keyPath, stack map[string]bool) []keyPath { + var out []keyPath + for _, f := range m.Fields() { + next := make(keyPath, 0, len(path)+1) + next = append(next, path...) + next = append(next, f) + var key bool + f.Extension(protodb.E_Key, &key) + if key { + out = append(out, next) + } + if !f.Type().IsEmbed() || f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() { + continue + } + embed := f.Type().Embed() + if embed == nil { + continue + } + name := embed.FullyQualifiedName() + if stack[name] { + continue + } + stack[name] = true + out = append(out, p.collectMessageKeysIn(embed, next, stack)...) + delete(stack, name) + } + return out +} + +func (p keyPath) leaf() pgs.Field { + return p[len(p)-1] +} + +func (p keyPath) String() string { + parts := make([]string, 0, len(p)) + for _, f := range p { + parts = append(parts, f.Name().String()) + } + return strings.Join(parts, ".") +} + func (p *mod) generate(f pgs.File, msgs []pgs.Message) { if len(msgs) == 0 { return diff --git a/internal/db/db.go b/internal/db/db.go index 66a4eba..c5b7ff9 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -77,7 +77,7 @@ func Open(ctx context.Context, opts ...Option) (protodb.DB, error) { } db := &db{opts: o, reg: reg, smu: mutex.NewKV(), ctxmu: mutex.NewContextKV()} db.matcher = pf.NewMatcher() - db.idx = idxstore.NewIndexer(db.reg, db.reg.Files, db.unmarshal) + db.idx = idxstore.NewIndexer(db.reg, db.unmarshal) if o.repl != nil { h, err := server.NewServer(db) if err != nil { @@ -206,7 +206,7 @@ func (db *db) Watch(ctx context.Context, m proto.Message, opts ...protodb.GetOpt o := makeGetOpts(opts...) matcher := db.matcher - k, _, _, _ := protodb.DataPrefix(m) + k, _, _, _ := db.reg.DataPrefix(m) log := logger.C(ctx).WithFields("service", "protodb", "action", "watch", "key", string(k)) log.Debugf("start watching for key %s", string(k)) ch := make(chan protodb.Event, 1) diff --git a/internal/db/tx.go b/internal/db/tx.go index 5ec9808..4516c10 100644 --- a/internal/db/tx.go +++ b/internal/db/tx.go @@ -144,7 +144,7 @@ func (tx *tx) get(ctx context.Context, m proto.Message, opts ...protodb.GetOptio } else if lim := o.Paging.GetLimit(); lim > 0 { out = make([]proto.Message, 0, lim) } - prefix, field, value, _ := protodb.DataPrefix(m) + prefix, field, value, _ := tx.db.reg.DataPrefix(m) span.SetAttributes( attribute.String("prefix", string(prefix)), attribute.String("key_field", field), @@ -499,7 +499,7 @@ func (tx *tx) set(ctx context.Context, m proto.Message, opts ...protodb.SetOptio if o.TTL != 0 { expiresAt = uint64(time.Now().Add(o.TTL).Unix()) } - k, field, value, err := protodb.DataPrefix(m) + k, field, value, err := tx.db.reg.DataPrefix(m) if err != nil { return nil, err } @@ -645,7 +645,7 @@ func (tx *tx) delete(ctx context.Context, m proto.Message) error { return badger.ErrReadOnlyTxn } // TODO(adphi): should we check / read for key first ? - k, field, value, err := protodb.DataPrefix(m) + k, field, value, err := tx.db.reg.DataPrefix(m) if err != nil { return err } @@ -878,7 +878,7 @@ func (tx *tx) getOrdered(ctx context.Context, m proto.Message, o protodb.GetOpts ) defer span.End() } - plan, err := buildOrderPlan(m.ProtoReflect().New(), o.OrderBy) + plan, err := buildOrderPlanWithKeyField(m.ProtoReflect().New(), o.OrderBy, tx.db.reg.KeyFieldName) if err != nil { return nil, nil, err } @@ -929,7 +929,7 @@ func (tx *tx) getOrderedIndexed(ctx context.Context, m proto.Message, o protodb. span.SetAttributes(attribute.Bool("fallback", true), attribute.String("fallback_reason", "key_order")) return nil, nil, false, nil } - prefix, _, _, _ := protodb.DataPrefix(m) + prefix, _, _, _ := tx.db.reg.DataPrefix(m) if o.Filter != nil { ok, err := tx.db.idx.IndexableFilter(m, o.Filter) if err != nil { @@ -1082,7 +1082,7 @@ func (tx *tx) getOrderedFallbackScanSort(ctx context.Context, m proto.Message, o ordered := make([]orderedResult, 0, len(all)) for _, item := range all { - key, _, _, err := protodb.DataPrefix(item) + key, _, _, err := tx.db.reg.DataPrefix(item) if err != nil { return nil, nil, err } @@ -1158,11 +1158,15 @@ func (tx *tx) getOrderedFallbackScanSort(ctx context.Context, m proto.Message, o } func buildOrderPlan(msg protoreflect.Message, orderBy *v1alpha1.OrderBy) (orderField, error) { + return buildOrderPlanWithKeyField(msg, orderBy, protodb.KeyFieldName) +} + +func buildOrderPlanWithKeyField(msg protoreflect.Message, orderBy *v1alpha1.OrderBy, keyFieldName func(protoreflect.MessageDescriptor) (string, bool)) (orderField, error) { if orderBy == nil { return orderField{}, errors.New("order_by cannot be empty") } md := msg.Descriptor() - keyField, hasKey := protodb.KeyFieldName(md) + keyField, hasKey := keyFieldName(md) fieldPath := strings.TrimSpace(orderBy.GetField()) if fieldPath == "" { return orderField{}, errors.New("order_by field cannot be empty") @@ -1174,7 +1178,7 @@ func buildOrderPlan(msg protoreflect.Message, orderBy *v1alpha1.OrderBy) (orderF if !isOrderableFieldPath(fds) { return orderField{}, fmt.Errorf("order_by field %q is not sortable", fieldPath) } - isKey := hasKey && len(fds) == 1 && string(fds[0].Name()) == keyField + isKey := hasKey && nameFieldPath(fds) == keyField if !isKey { fd := fds[len(fds)-1] if !proto.HasExtension(fd.Options(), protopts.E_Index) { @@ -1319,3 +1323,11 @@ func hash(f protodb.Filter) (hash string, err error) { h := sha512.Sum512(b) return base64.StdEncoding.EncodeToString(h[:]), nil } + +func nameFieldPath(fds []protoreflect.FieldDescriptor) string { + parts := make([]string, 0, len(fds)) + for _, fd := range fds { + parts = append(parts, string(fd.Name())) + } + return strings.Join(parts, ".") +} diff --git a/internal/db/tx_test.go b/internal/db/tx_test.go index 7c34fe1..2e3885a 100644 --- a/internal/db/tx_test.go +++ b/internal/db/tx_test.go @@ -6,13 +6,18 @@ import ( "testing" "github.com/dgraph-io/badger/v3" + "github.com/jhump/protoreflect/v2/protobuilder" "github.com/stretchr/testify/require" "go.linka.cloud/protofilters/filters" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" "go.linka.cloud/protodb/internal/protodb" "go.linka.cloud/protodb/internal/token" + protopts "go.linka.cloud/protodb/protodb" v1alpha1 "go.linka.cloud/protodb/protodb/v1alpha1" ) @@ -55,6 +60,18 @@ func TestBuildOrderPlan(t *testing.T) { _, err := buildOrderPlan(m.ProtoReflect(), &v1alpha1.OrderBy{Field: "status", Direction: v1alpha1.OrderDirection(99)}) require.ErrorContains(t, err, "invalid direction") }) + + require.NoError(t, d.RegisterProto(ctx, buildNestedKeyOrderFileDescriptor(t))) + nmd, err := lookupMessage(d, "tests.order.OrderDoc") + require.NoError(t, err) + nm := dynamicpb.NewMessage(nmd) + + t.Run("nested key field", func(t *testing.T) { + pl, err := buildOrderPlan(nm.ProtoReflect(), &v1alpha1.OrderBy{Field: "metadata.id"}) + require.NoError(t, err) + require.Equal(t, "metadata.id", pl.fieldPath) + require.Equal(t, v1alpha1.OrderDirectionAsc, pl.direction) + }) } func TestGetContinuationValidation(t *testing.T) { @@ -297,6 +314,30 @@ func lookupNumberField(md protoreflect.MessageDescriptor, n string) (protoreflec return nil, fmt.Errorf("field number %s not found", n) } +func buildNestedKeyOrderFileDescriptor(t *testing.T) *descriptorpb.FileDescriptorProto { + t.Helper() + keyOpts := &descriptorpb.FieldOptions{} + proto.SetExtension(keyOpts, protopts.E_Key, true) + + meta := protobuilder.NewMessage("Metadata"). + AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts)) + + msg := protobuilder.NewMessage("OrderDoc"). + AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)). + AddField(protobuilder.NewField("status", protobuilder.FieldTypeString()).SetNumber(2)). + AddNestedMessage(meta) + + file := protobuilder.NewFile("tests/order_nested_key.proto"). + SetPackageName(protoreflect.FullName("tests.order")). + SetSyntax(protoreflect.Proto3). + AddMessage(msg). + AddImportedDependency(protopts.File_protodb_protodb_proto) + + fd, err := file.Build() + require.NoError(t, err) + return protodesc.ToFileDescriptorProto(fd) +} + func TestGetWithMalformedTokenReturnsInvalid(t *testing.T) { ctx := context.Background() dbif, err := Open(ctx, WithPath(t.TempDir()), WithApplyDefaults(true)) diff --git a/internal/index/indexer.go b/internal/index/indexer.go index 717850a..dac8587 100644 --- a/internal/index/indexer.go +++ b/internal/index/indexer.go @@ -34,29 +34,24 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" "go.linka.cloud/protodb/internal/badgerd" "go.linka.cloud/protodb/internal/protodb" + "go.linka.cloud/protodb/internal/registry" protopts "go.linka.cloud/protodb/protodb" ) var idxTracer = otel.Tracer("protodb.indexer") -type FileRegistry interface { - RangeFiles(func(protoreflect.FileDescriptor) bool) -} - type Tx interface { Txn() badgerd.Tx UID(ctx context.Context, key []byte, inc bool) (uint64, bool, error) } type Indexer struct { - reg protodesc.Resolver - freg FileRegistry + reg *registry.Registry unmarshal func([]byte, proto.Message) error mu sync.RWMutex entries map[protoreflect.FullName]entryCache @@ -73,8 +68,8 @@ type entryCache struct { entries []string } -func NewIndexer(reg protodesc.Resolver, freg FileRegistry, unmarshal func([]byte, proto.Message) error) *Indexer { - return &Indexer{reg: reg, freg: freg, unmarshal: unmarshal} +func NewIndexer(reg *registry.Registry, unmarshal func([]byte, proto.Message) error) *Indexer { + return &Indexer{reg: reg, unmarshal: unmarshal} } func (idx *Indexer) RebuildIfNeeded(ctx context.Context, tx Tx, files ...protoreflect.FileDescriptor) error { @@ -88,7 +83,7 @@ func (idx *Indexer) Rebuild(ctx context.Context, tx Tx, files ...protoreflect.Fi defer span.End() msgs := collectMessagesFromFiles(files) if len(files) == 0 { - msgs = collectMessageDescriptors(idx.freg) + msgs = collectMessageDescriptors(idx.reg) } span.SetAttributes( attribute.Int("index.messages", len(msgs)), @@ -175,7 +170,7 @@ func (idx *Indexer) IndexableFilter(m proto.Message, f protodb.Filter) (bool, er if f == nil || f.Expr() == nil { return false, nil } - return isIndexableExpr(m.ProtoReflect().New(), f.Expr()) + return idx.isIndexableExpr(m.ProtoReflect().New(), f.Expr()) } func (idx *Indexer) EnforceUnique(ctx context.Context, tx Tx, m proto.Message, uid uint64) error { @@ -820,7 +815,7 @@ func collectIndexEntries(md protoreflect.MessageDescriptor) []string { walk(fd.Message(), prefix+fmt.Sprintf("%d", fd.Number())+".") continue } - if !proto.HasExtension(fd.Options(), protopts.E_Index) || !isIndexableField(fd) { + if !proto.HasExtension(fd.Options(), protopts.E_Index) || !IsIndexableLeaf(fd) { continue } entries = append(entries, fmt.Sprintf("%s%d|%s|%d", prefix, fd.Number(), fd.Kind(), fd.Cardinality())) @@ -831,12 +826,12 @@ func collectIndexEntries(md protoreflect.MessageDescriptor) []string { return entries } -func collectMessageDescriptors(freg FileRegistry) []protoreflect.MessageDescriptor { - if freg == nil { +func collectMessageDescriptors(reg *registry.Registry) []protoreflect.MessageDescriptor { + if reg == nil { return nil } var out []protoreflect.MessageDescriptor - freg.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + reg.RangeFiles(func(fd protoreflect.FileDescriptor) bool { out = append(out, collectMessages(fd.Messages())...) return true }) @@ -937,24 +932,24 @@ func collectUniqueFromMsgList(val protoreflect.Value, fds ...protoreflect.FieldD return out, nil } -func isIndexableExpr(msg protoreflect.Message, expr *filters.Expression) (bool, error) { +func (idx *Indexer) isIndexableExpr(msg protoreflect.Message, expr *filters.Expression) (bool, error) { if expr == nil { return true, nil } if expr.Condition != nil { - ok, err := isIndexableFieldPath(msg, expr.Condition.GetField()) + ok, err := idx.isIndexableFieldPath(msg, expr.Condition.GetField()) if err != nil || !ok { return ok, err } } for _, v := range expr.AndExprs { - ok, err := isIndexableExpr(msg, v) + ok, err := idx.isIndexableExpr(msg, v) if err != nil || !ok { return ok, err } } for _, v := range expr.OrExprs { - ok, err := isIndexableExpr(msg, v) + ok, err := idx.isIndexableExpr(msg, v) if err != nil || !ok { return ok, err } @@ -962,7 +957,7 @@ func isIndexableExpr(msg protoreflect.Message, expr *filters.Expression) (bool, return true, nil } -func isIndexableFieldPath(msg protoreflect.Message, fieldPath string) (bool, error) { +func (idx *Indexer) isIndexableFieldPath(msg protoreflect.Message, fieldPath string) (bool, error) { if fieldPath == "" { return false, nil } @@ -974,24 +969,31 @@ func isIndexableFieldPath(msg protoreflect.Message, fieldPath string) (bool, err if proto.HasExtension(fd.Options(), protopts.E_Index) { return isIndexableFieldPathDescriptors(fds), nil } - if !isKeyField(msg.Descriptor(), fds) { + if !idx.isKeyField(msg.Descriptor(), fds) { return false, nil } return isIndexableFieldPathDescriptors(fds), nil } -func isKeyField(md protoreflect.MessageDescriptor, fds []protoreflect.FieldDescriptor) bool { +func (idx *Indexer) isKeyField(md protoreflect.MessageDescriptor, fds []protoreflect.FieldDescriptor) bool { if md == nil { return false } - if len(fds) != 1 { + if len(fds) == 0 { return false } - field, ok := protodb.KeyFieldName(md) + field, ok := keyName(md, idx.reg) if !ok { return false } - return field == string(fds[0].Name()) + return field == fieldPathFromNames(fds) +} + +func keyName(md protoreflect.MessageDescriptor, reg *registry.Registry) (string, bool) { + if reg == nil { + return protodb.KeyFieldName(md) + } + return reg.KeyFieldName(md) } func isIndexableFieldPathDescriptors(fds []protoreflect.FieldDescriptor) bool { @@ -1016,10 +1018,6 @@ func isIndexableFieldPathDescriptors(fds []protoreflect.FieldDescriptor) bool { return IsIndexableLeaf(last) } -func isIndexableField(fd protoreflect.FieldDescriptor) bool { - return IsIndexableLeaf(fd) -} - // IsIndexableLeaf reports whether a field descriptor supports index ordering/lookup. func IsIndexableLeaf(fd protoreflect.FieldDescriptor) bool { switch fd.Kind() { diff --git a/internal/index/indexer_test.go b/internal/index/indexer_test.go index 70957e0..f78e86d 100644 --- a/internal/index/indexer_test.go +++ b/internal/index/indexer_test.go @@ -12,17 +12,17 @@ import ( "go.linka.cloud/protofilters/index/bitmap" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/dynamicpb" "go.linka.cloud/protodb/internal/badgerd" "go.linka.cloud/protodb/internal/protodb" + regpkg "go.linka.cloud/protodb/internal/registry" protopts "go.linka.cloud/protodb/protodb" ) func TestCollectEntriesCacheAndRefresh(t *testing.T) { - idx := NewIndexer(nil, nil, nil) + idx := NewIndexer(nil, nil) md1 := buildIndexerDocDescriptorV1(t) md2 := buildIndexerDocDescriptorV2(t) @@ -38,7 +38,7 @@ func TestCollectEntriesCacheAndRefresh(t *testing.T) { } func TestIndexableFilter(t *testing.T) { - idx := NewIndexer(nil, nil, nil) + idx := NewIndexer(nil, nil) md := buildIndexerDocDescriptorV1(t) m := dynamicpb.NewMessage(md) @@ -52,6 +52,12 @@ func TestIndexableFilter(t *testing.T) { _, err = idx.IndexableFilter(m, filters.Where("meta.nope").StringEquals("x")) require.Error(t, err) + + nestedKey := buildIndexerNestedKeyDescriptor(t) + nestedMsg := dynamicpb.NewMessage(nestedKey) + ok, err = idx.IndexableFilter(nestedMsg, filters.Where("metadata.id").StringEquals("x")) + require.NoError(t, err) + require.True(t, ok) } func TestEnforceUnique(t *testing.T) { @@ -60,7 +66,7 @@ func TestEnforceUnique(t *testing.T) { require.NoError(t, err) defer db.Close() - idx := NewIndexer(nil, nil, nil) + idx := NewIndexer(nil, nil) md := buildUniqueDocDescriptor(t) emailFD := md.Fields().ByName("email") require.NotNil(t, emailFD) @@ -99,7 +105,7 @@ func TestEnforceUniqueDeduplicatesCollectedValues(t *testing.T) { require.NoError(t, err) defer db.Close() - idx := NewIndexer(nil, nil, nil) + idx := NewIndexer(nil, nil) md := buildUniqueTagsDescriptor(t) tagsFD := md.Fields().ByName("tags") require.NotNil(t, tagsFD) @@ -163,7 +169,7 @@ func TestOrderedUIDGroupsSeq(t *testing.T) { require.NoError(t, err) defer rtx.Close(ctx) - idx := NewIndexer(nil, nil, nil) + idx := NewIndexer(nil, nil) fds := []protoreflect.FieldDescriptor{statusFD} ascSeq, err := idx.OrderedUIDGroupsSeq(ctx, indexTx{tx: rtx}, md.FullName(), fds, false) @@ -212,7 +218,7 @@ func TestOrderedUIDGroupsSeqContextCanceled(t *testing.T) { cctx, cancel := context.WithCancel(context.Background()) cancel() - idx := NewIndexer(nil, nil, nil) + idx := NewIndexer(nil, nil) seq, err := idx.OrderedUIDGroupsSeq(cctx, indexTx{tx: rtx}, md.FullName(), []protoreflect.FieldDescriptor{statusFD}, false) require.NoError(t, err) for _, err := range seq { @@ -349,6 +355,32 @@ func buildUniqueDocDescriptor(t *testing.T) protoreflect.MessageDescriptor { return md } +func buildIndexerNestedKeyDescriptor(t *testing.T) protoreflect.MessageDescriptor { + t.Helper() + keyOpts := &descriptorpb.FieldOptions{} + proto.SetExtension(keyOpts, protopts.E_Key, true) + + meta := protobuilder.NewMessage("Metadata"). + AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts)) + + msg := protobuilder.NewMessage("NestedKeyDoc"). + AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)). + AddField(protobuilder.NewField("status", protobuilder.FieldTypeString()).SetNumber(2).SetOptions(&descriptorpb.FieldOptions{})). + AddNestedMessage(meta) + + file := protobuilder.NewFile("tests/indexer_nested_key.proto"). + SetPackageName(protoreflect.FullName("tests.index")). + SetSyntax(protoreflect.Proto3). + AddMessage(msg). + AddImportedDependency(protopts.File_protodb_protodb_proto) + + fd, err := file.Build() + require.NoError(t, err) + md := fd.Messages().ByName("NestedKeyDoc") + require.NotNil(t, md) + return md +} + func buildUniqueTagsDescriptor(t *testing.T) protoreflect.MessageDescriptor { t.Helper() keyOpts := &descriptorpb.FieldOptions{} @@ -373,7 +405,9 @@ func buildUniqueTagsDescriptor(t *testing.T) protoreflect.MessageDescriptor { } func TestCollectEntriesDeterministic(t *testing.T) { - idx := NewIndexer(&protoregistry.Files{}, nil, nil) + reg, err := regpkg.New() + require.NoError(t, err) + idx := NewIndexer(reg, nil) md := buildIndexerDocDescriptorV1(t) first := idx.CollectEntries(md) second := idx.CollectEntries(md) diff --git a/internal/index/store.go b/internal/index/store.go index 01a847f..bdb7a69 100644 --- a/internal/index/store.go +++ b/internal/index/store.go @@ -321,7 +321,7 @@ func (f *fieldReader) Get(_ context.Context, n protoreflect.Name) iter.Seq2[pfin type keyField struct { name protoreflect.Name - fd protoreflect.FieldDescriptor + fds []protoreflect.FieldDescriptor dataPrefix []byte txn badgerd.Tx entries []keyEntry @@ -371,7 +371,7 @@ func (k *keyField) iterate(yield func(pfindex.Field, error) bool) { } for _, e := range k.entries { entry := &field{ - descriptors: []protoreflect.FieldDescriptor{k.fd}, + descriptors: k.fds, value: protoreflect.ValueOfString(e.value), bitmap: e.bitmap, } @@ -558,13 +558,7 @@ func fieldPathFromNames(fds []protoreflect.FieldDescriptor) string { } func lookupByNumberPath(md protoreflect.MessageDescriptor, fieldPath string) ([]protoreflect.FieldDescriptor, error) { - parts := strings.Split(fieldPath, ".") - if len(parts) == 0 { - return nil, fmt.Errorf("empty field path") - } - var fds []protoreflect.FieldDescriptor - cur := md - for _, part := range parts { + return lookupByPath(md, fieldPath, func(cur protoreflect.MessageDescriptor, part string) (protoreflect.FieldDescriptor, error) { num, err := strconv.Atoi(part) if err != nil { return nil, fmt.Errorf("invalid field number %q", part) @@ -573,19 +567,56 @@ func lookupByNumberPath(md protoreflect.MessageDescriptor, fieldPath string) ([] if fd == nil { return nil, fmt.Errorf("%s does not contain field number %d", cur.FullName(), num) } + return fd, nil + }) +} + +func lookupByNamePath(md protoreflect.MessageDescriptor, fieldPath string) ([]protoreflect.FieldDescriptor, error) { + return lookupByPath(md, fieldPath, func(cur protoreflect.MessageDescriptor, part string) (protoreflect.FieldDescriptor, error) { + fd := cur.Fields().ByName(protoreflect.Name(part)) + if fd == nil { + return nil, fmt.Errorf("%s does not contain field '%s'", cur.FullName(), part) + } + return fd, nil + }) +} + +func lookupByPath(md protoreflect.MessageDescriptor, fieldPath string, lookup func(protoreflect.MessageDescriptor, string) (protoreflect.FieldDescriptor, error)) ([]protoreflect.FieldDescriptor, error) { + if fieldPath == "" { + return nil, fmt.Errorf("empty field path") + } + parts := strings.Split(fieldPath, ".") + var fds []protoreflect.FieldDescriptor + cur := md + for i, part := range parts { + fd, err := lookup(cur, part) + if err != nil { + return nil, err + } fds = append(fds, fd) - if fd.Kind() == protoreflect.MessageKind { - cur = fd.Message() - } else { - cur = nil + if i == len(parts)-1 { + continue } - if cur == nil && part != parts[len(parts)-1] { + if fd.Kind() != protoreflect.MessageKind { return nil, fmt.Errorf("%s does not contain '%s'", md.FullName(), fieldPath) } + cur = fd.Message() } return fds, nil } +type keyFieldNamer interface { + KeyFieldName(protoreflect.MessageDescriptor) (string, bool) +} + +func keyFieldName(resolver protodesc.Resolver, md protoreflect.MessageDescriptor) (string, bool) { + r, ok := resolver.(keyFieldNamer) + if ok { + return r.KeyFieldName(md) + } + return protodb.KeyFieldName(md) +} + func buildFieldReader(txn badgerd.Tx, resolver protodesc.Resolver, name protoreflect.FullName, precompute bool, owner *tx) (*fieldReader, error) { d, err := resolver.FindDescriptorByName(name) if err != nil { @@ -596,13 +627,13 @@ func buildFieldReader(txn badgerd.Tx, resolver protodesc.Resolver, name protoref return nil, fmt.Errorf("descriptor %s is not a message", name) } var key *keyField - keyName, ok := protodb.KeyFieldName(md) + keyName, ok := keyFieldName(resolver, md) if ok { - fd := md.Fields().ByName(protoreflect.Name(keyName)) - if fd != nil { + fds, err := lookupByNamePath(md, keyName) + if err == nil { key = &keyField{ - name: fd.Name(), - fd: fd, + name: protoreflect.Name(fieldPathFromNames(fds)), + fds: fds, dataPrefix: []byte(protodb.Data + "/" + string(name) + "/"), txn: txn, } diff --git a/internal/index/store_test.go b/internal/index/store_test.go index aa47f3c..b692630 100644 --- a/internal/index/store_test.go +++ b/internal/index/store_test.go @@ -27,10 +27,12 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" "go.linka.cloud/protodb/internal/badgerd" + protopts "go.linka.cloud/protodb/protodb" ) func TestValuePrefixRoundTrip(t *testing.T) { @@ -196,6 +198,29 @@ func TestBuildFieldReaderUsesInferredKeyField(t *testing.T) { require.NoError(t, rt.Close(ctx)) } +func TestBuildFieldReaderUsesNestedKeyField(t *testing.T) { + ctx := context.Background() + md := buildNestedKeyDescriptor(t) + + reg := &protoregistry.Files{} + require.NoError(t, reg.RegisterFile(md.ParentFile())) + + db, err := badgerd.Open(ctx, badgerd.WithInMemory(true)) + require.NoError(t, err) + defer db.Close() + + rt, err := db.NewTransaction(ctx, false) + require.NoError(t, err) + reader, err := buildFieldReader(rt, reg, md.FullName(), true, nil) + require.NoError(t, err) + require.NotNil(t, reader.key) + require.Equal(t, protoreflect.Name("metadata.id"), reader.key.name) + require.Len(t, reader.key.fds, 2) + require.Equal(t, protoreflect.Name("metadata"), reader.key.fds[0].Name()) + require.Equal(t, protoreflect.Name("id"), reader.key.fds[1].Name()) + require.NoError(t, rt.Close(ctx)) +} + func buildTypesDescriptor(t *testing.T) protoreflect.MessageDescriptor { keyEnum := protobuilder.NewEnum("Status"). AddValue(protobuilder.NewEnumValue("UNKNOWN").SetNumber(0)). @@ -261,6 +286,31 @@ func buildIDDescriptor(t *testing.T) protoreflect.MessageDescriptor { return md } +func buildNestedKeyDescriptor(t *testing.T) protoreflect.MessageDescriptor { + keyOpts := &descriptorpb.FieldOptions{} + proto.SetExtension(keyOpts, protopts.E_Key, true) + + meta := protobuilder.NewMessage("Metadata"). + AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts)) + + msg := protobuilder.NewMessage("WithNestedKey"). + AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)). + AddField(protobuilder.NewField("status", protobuilder.FieldTypeString()).SetNumber(2)). + AddNestedMessage(meta) + + file := protobuilder.NewFile("tests/with_nested_key.proto"). + SetPackageName(protoreflect.FullName("tests.index")). + SetSyntax(protoreflect.Proto3). + AddMessage(msg). + AddImportedDependency(protopts.File_protodb_protodb_proto) + + fd, err := file.Build() + require.NoError(t, err) + md := fd.Messages().ByName("WithNestedKey") + require.NotNil(t, md) + return md +} + func assertValueEqual(t *testing.T, fd protoreflect.FieldDescriptor, exp, got protoreflect.Value) { switch fd.Kind() { case protoreflect.FloatKind: diff --git a/internal/protodb/key.go b/internal/protodb/key.go index 0256a6b..3cbc140 100644 --- a/internal/protodb/key.go +++ b/internal/protodb/key.go @@ -37,33 +37,19 @@ func KeyFromOpts(m proto.Message) (key string, field string, ok bool) { if sk != nil && sk.(string) != "" { return sk.(string), "", true } - var kf protoreflect.FieldDescriptor - fields := m.ProtoReflect().Type().Descriptor().Fields() - for i := 0; i < fields.Len(); i++ { - f := fields.Get(i) - o, ok := f.Options().(*descriptorpb.FieldOptions) - if !ok { - continue - } - v := proto.GetExtension(o, protodb.E_Key) - if v == nil { - continue - } - b := v.(bool) - if b { - kf = f - break - } - } - if kf == nil { + path, ok := keyFieldPathFromOpts(m.ProtoReflect().Type().Descriptor()) + if !ok { return "", "", false } - field = string(kf.Name()) - v := m.ProtoReflect().Get(kf) - if !v.IsValid() { + field = keyPathFromNames(path) + if !isKeyLeaf(path[len(path)-1]) { + return "", field, true + } + v, ok := keyPathValue(m.ProtoReflect(), path) + if !ok || !v.IsValid() { return "", field, true } - if k := v.String(); k != "" && k != "0" { + if k, ok := keyString(v, path[len(path)-1]); ok { return k, field, true } return "", field, true @@ -280,6 +266,9 @@ func KeyFieldName(md protoreflect.MessageDescriptor) (string, bool) { if md == nil { return "", false } + if path, ok := keyFieldPathFromOpts(md); ok { + return keyPathFromNames(path), true + } m := dynamicpb.NewMessage(md) _, field, _ := KeyFor(m) if field != "" { @@ -306,6 +295,154 @@ func KeyFieldName(md protoreflect.MessageDescriptor) (string, bool) { return "", false } +func keyFieldPathFromOpts(md protoreflect.MessageDescriptor) ([]protoreflect.FieldDescriptor, bool) { + if md == nil { + return nil, false + } + stack := map[protoreflect.FullName]bool{md.FullName(): true} + return keyFieldPathFromOptsIn(md, nil, stack) +} + +func KeyFieldPathFromOpts(md protoreflect.MessageDescriptor) ([]protoreflect.FieldDescriptor, bool) { + return keyFieldPathFromOpts(md) +} + +func keyFieldPathFromOptsIn(md protoreflect.MessageDescriptor, path []protoreflect.FieldDescriptor, stack map[protoreflect.FullName]bool) ([]protoreflect.FieldDescriptor, bool) { + fields := md.Fields() + for i := 0; i < fields.Len(); i++ { + fd := fields.Get(i) + next := appendFieldPath(path, fd) + if isExplicitKey(fd) { + return next, true + } + if fd.Kind() != protoreflect.MessageKind || fd.IsList() || fd.IsMap() { + continue + } + child := fd.Message() + if child == nil || stack[child.FullName()] { + continue + } + stack[child.FullName()] = true + if nested, ok := keyFieldPathFromOptsIn(child, next, stack); ok { + return nested, true + } + delete(stack, child.FullName()) + } + return nil, false +} + +func appendFieldPath(path []protoreflect.FieldDescriptor, fd protoreflect.FieldDescriptor) []protoreflect.FieldDescriptor { + next := make([]protoreflect.FieldDescriptor, 0, len(path)+1) + next = append(next, path...) + return append(next, fd) +} + +func isExplicitKey(fd protoreflect.FieldDescriptor) bool { + if fd == nil { + return false + } + o, ok := fd.Options().(*descriptorpb.FieldOptions) + if !ok { + return false + } + v := proto.GetExtension(o, protodb.E_Key) + if v == nil { + return false + } + b, ok := v.(bool) + if !ok { + return false + } + return b +} + +func keyPathFromNames(path []protoreflect.FieldDescriptor) string { + parts := make([]string, 0, len(path)) + for _, fd := range path { + parts = append(parts, string(fd.Name())) + } + return strings.Join(parts, ".") +} + +func KeyPathFromNames(path []protoreflect.FieldDescriptor) string { + return keyPathFromNames(path) +} + +func keyPathValue(m protoreflect.Message, path []protoreflect.FieldDescriptor) (protoreflect.Value, bool) { + cur := m + for i, fd := range path { + v := cur.Get(fd) + if i == len(path)-1 { + return v, true + } + if fd.Kind() != protoreflect.MessageKind { + return protoreflect.Value{}, false + } + next := v.Message() + if !next.IsValid() { + return protoreflect.Value{}, false + } + cur = next + } + return protoreflect.Value{}, false +} + +func keyString(v protoreflect.Value, fd protoreflect.FieldDescriptor) (string, bool) { + if !v.IsValid() || fd == nil { + return "", false + } + if !hasKeyValue(v, fd) { + return "", false + } + return v.String(), true +} + +func KeyString(v protoreflect.Value, fd protoreflect.FieldDescriptor) (string, bool) { + return keyString(v, fd) +} + +func hasKeyValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) bool { + switch fd.Kind() { + case protoreflect.BoolKind: + return v.Bool() + case protoreflect.EnumKind: + return v.Enum() != 0 + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, + protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + return v.Int() != 0 + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, + protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + return v.Uint() != 0 + case protoreflect.FloatKind, protoreflect.DoubleKind: + return v.Float() != 0 + case protoreflect.StringKind: + return v.String() != "" + case protoreflect.BytesKind: + return len(v.Bytes()) != 0 + default: + return false + } +} + +func isKeyLeaf(fd protoreflect.FieldDescriptor) bool { + if fd == nil { + return false + } + if fd.IsMap() || fd.IsList() { + return false + } + switch fd.Kind() { + case protoreflect.MessageKind, protoreflect.GroupKind: + return false + default: + return true + } +} + +func KeyLeaf(fd protoreflect.FieldDescriptor) bool { + return isKeyLeaf(fd) +} + func keyProbeValue(fd protoreflect.FieldDescriptor) (protoreflect.Value, bool) { if fd == nil { return protoreflect.Value{}, false diff --git a/internal/protodb/key_test.go b/internal/protodb/key_test.go new file mode 100644 index 0000000..82207bc --- /dev/null +++ b/internal/protodb/key_test.go @@ -0,0 +1,67 @@ +package protodb + +import ( + "testing" + + "github.com/jhump/protoreflect/v2/protobuilder" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" + + protopts "go.linka.cloud/protodb/protodb" +) + +func TestKeyForNestedKeyOption(t *testing.T) { + md := buildNestedKeyDescriptor(t) + + m := dynamicpb.NewMessage(md) + metaFD := md.Fields().ByName("metadata") + meta := dynamicpb.NewMessage(metaFD.Message()) + meta.Set(metaFD.Message().Fields().ByName("id"), protoreflect.ValueOfString("r1")) + m.Set(metaFD, protoreflect.ValueOfMessage(meta)) + + key, field, err := KeyFor(m) + require.NoError(t, err) + require.Equal(t, "r1", key) + require.Equal(t, "metadata.id", field) +} + +func TestKeyFieldNameNestedKeyOption(t *testing.T) { + md := buildNestedKeyDescriptor(t) + field, ok := KeyFieldName(md) + require.True(t, ok) + require.Equal(t, "metadata.id", field) +} + +func buildNestedKeyDescriptor(t *testing.T) protoreflect.MessageDescriptor { + t.Helper() + keyOpts := &descriptorpb.FieldOptions{} + proto.SetExtension(keyOpts, protopts.E_Key, true) + + meta := protobuilder.NewMessage("Metadata"). + AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts)). + AddField(protobuilder.NewField("label", protobuilder.FieldTypeString()).SetNumber(2)) + + resource := protobuilder.NewMessage("Resource"). + AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)). + AddField(protobuilder.NewField("field", protobuilder.FieldTypeString()).SetNumber(2)). + AddNestedMessage(meta) + + file := protobuilder.NewFile("tests/nested_key.proto"). + SetPackageName(protoreflect.FullName("tests.key")). + SetSyntax(protoreflect.Proto3). + AddMessage(resource). + AddImportedDependency(protopts.File_protodb_protodb_proto) + + fd, err := file.Build() + require.NoError(t, err) + protoFd := protodesc.ToFileDescriptorProto(fd) + require.NotNil(t, protoFd) + + md := fd.Messages().ByName("Resource") + require.NotNil(t, md) + return md +} diff --git a/internal/registry/init.go b/internal/registry/init.go index e49d1a2..dc8123a 100644 --- a/internal/registry/init.go +++ b/internal/registry/init.go @@ -15,6 +15,8 @@ package registry import ( + "sync" + "google.golang.org/protobuf/reflect/protoreflect" preg "google.golang.org/protobuf/reflect/protoregistry" ) @@ -22,6 +24,9 @@ import ( type Registry struct { *Files *Types + + mu sync.RWMutex + keys map[protoreflect.FullName]key } func New() (*Registry, error) { diff --git a/internal/registry/key.go b/internal/registry/key.go new file mode 100644 index 0000000..8a425b5 --- /dev/null +++ b/internal/registry/key.go @@ -0,0 +1,238 @@ +package registry + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + iprotodb "go.linka.cloud/protodb/internal/protodb" + protopts "go.linka.cloud/protodb/protodb" +) + +type key struct { + field string + path []protoreflect.FieldNumber + static string + leaf bool +} + +func (r *Registry) RegisterFile(file protoreflect.FileDescriptor) error { + if err := r.Files.RegisterFile(file); err != nil { + return err + } + r.dropFileKeys(file) + return nil +} + +func (r *Registry) RegisterMessage(mt protoreflect.MessageType) error { + if err := r.Types.RegisterMessage(mt); err != nil { + return err + } + r.dropKey(mt.Descriptor().FullName()) + return nil +} + +func (r *Registry) RegisterEnum(et protoreflect.EnumType) error { + return r.Types.RegisterEnum(et) +} + +func (r *Registry) RegisterExtension(xt protoreflect.ExtensionType) error { + return r.Types.RegisterExtension(xt) +} + +func (r *Registry) KeyFieldName(md protoreflect.MessageDescriptor) (string, bool) { + if r == nil { + return iprotodb.KeyFieldName(md) + } + k, ok := r.key(md) + if !ok { + return "", false + } + if k.field != "" { + return k.field, true + } + return iprotodb.KeyFieldName(md) +} + +func (r *Registry) KeyFor(m proto.Message) (key, field string, err error) { + if r == nil { + return iprotodb.KeyFor(m) + } + if m == nil { + return "", "", fmt.Errorf("key / id not found in ") + } + k, ok := r.key(m.ProtoReflect().Descriptor()) + if !ok { + return iprotodb.KeyFor(m) + } + if k.static != "" { + return k.static, "", nil + } + if len(k.path) == 0 { + return iprotodb.KeyFor(m) + } + field = k.field + if !k.leaf { + return "", field, keyNotFound(m) + } + v, fd, ok := keyValue(m.ProtoReflect(), k.path) + if !ok || !v.IsValid() { + return "", field, keyNotFound(m) + } + if key, ok := iprotodb.KeyString(v, fd); ok { + return key, field, nil + } + return "", field, keyNotFound(m) +} + +func (r *Registry) DataPrefix(m proto.Message) (key []byte, f string, value string, err error) { + if r == nil { + return iprotodb.DataPrefix(m) + } + k, f, err := r.KeyFor(m) + if err != nil { + return fmt.Appendf(nil, "%s/%s/", iprotodb.Data, m.ProtoReflect().Descriptor().FullName()), f, k, fmt.Errorf("key: %w", err) + } + return fmt.Appendf(nil, "%s/%s/%s", iprotodb.Data, m.ProtoReflect().Descriptor().FullName(), k), f, k, nil +} + +func (r *Registry) key(md protoreflect.MessageDescriptor) (key, bool) { + if r == nil { + return key{}, false + } + if md == nil { + return key{}, false + } + name := md.FullName() + r.mu.RLock() + if entry, ok := r.keys[name]; ok { + r.mu.RUnlock() + return entry, true + } + r.mu.RUnlock() + entry := buildKeyEntry(md) + r.mu.Lock() + if r.keys == nil { + r.keys = map[protoreflect.FullName]key{} + } + if cached, ok := r.keys[name]; ok { + r.mu.Unlock() + return cached, true + } + r.keys[name] = entry + r.mu.Unlock() + return entry, true +} + +func (r *Registry) dropKey(name protoreflect.FullName) { + r.mu.Lock() + delete(r.keys, name) + r.mu.Unlock() +} + +func (r *Registry) dropFileKeys(file protoreflect.FileDescriptor) { + r.mu.Lock() + if len(r.keys) == 0 { + r.mu.Unlock() + return + } + msgs := file.Messages() + for i := 0; i < msgs.Len(); i++ { + dropNestedKeys(r.keys, msgs.Get(i)) + } + r.mu.Unlock() +} + +func dropNestedKeys(cache map[protoreflect.FullName]key, md protoreflect.MessageDescriptor) { + delete(cache, md.FullName()) + msgs := md.Messages() + for i := 0; i < msgs.Len(); i++ { + dropNestedKeys(cache, msgs.Get(i)) + } +} + +func buildKeyEntry(md protoreflect.MessageDescriptor) key { + entry := key{} + if sk, ok := staticKey(md); ok { + entry.static = sk + return entry + } + if path, ok := keyPathFromOpts(md); ok { + entry.path = numPath(path) + entry.field = iprotodb.KeyPathFromNames(path) + entry.leaf = iprotodb.KeyLeaf(path[len(path)-1]) + return entry + } + if path, ok := inferKeyField(md); ok { + entry.path = []protoreflect.FieldNumber{path.Number()} + entry.field = string(path.Name()) + entry.leaf = iprotodb.KeyLeaf(path) + } + return entry +} + +func staticKey(md protoreflect.MessageDescriptor) (string, bool) { + v := proto.GetExtension(md.Options(), protopts.E_StaticKey) + if v == nil { + return "", false + } + sk, ok := v.(string) + if !ok || sk == "" { + return "", false + } + return sk, true +} + +func inferKeyField(md protoreflect.MessageDescriptor) (protoreflect.FieldDescriptor, bool) { + for _, name := range []protoreflect.Name{"id", "key", "name"} { + fd := md.Fields().ByName(name) + if fd == nil { + continue + } + if !iprotodb.KeyLeaf(fd) { + continue + } + return fd, true + } + return nil, false +} + +func keyPathFromOpts(md protoreflect.MessageDescriptor) ([]protoreflect.FieldDescriptor, bool) { + return iprotodb.KeyFieldPathFromOpts(md) +} + +func numPath(path []protoreflect.FieldDescriptor) []protoreflect.FieldNumber { + out := make([]protoreflect.FieldNumber, 0, len(path)) + for _, fd := range path { + out = append(out, fd.Number()) + } + return out +} + +func keyValue(m protoreflect.Message, path []protoreflect.FieldNumber) (protoreflect.Value, protoreflect.FieldDescriptor, bool) { + cur := m + for i, num := range path { + fd := cur.Descriptor().Fields().ByNumber(num) + if fd == nil { + return protoreflect.Value{}, nil, false + } + v := cur.Get(fd) + if i == len(path)-1 { + return v, fd, true + } + if fd.Kind() != protoreflect.MessageKind { + return protoreflect.Value{}, nil, false + } + next := v.Message() + if !next.IsValid() { + return protoreflect.Value{}, nil, false + } + cur = next + } + return protoreflect.Value{}, nil, false +} + +func keyNotFound(m proto.Message) error { + return fmt.Errorf("key / id not found in %s", m.ProtoReflect().Type().Descriptor().FullName()) +} diff --git a/internal/registry/key_test.go b/internal/registry/key_test.go new file mode 100644 index 0000000..a8685b1 --- /dev/null +++ b/internal/registry/key_test.go @@ -0,0 +1,136 @@ +package registry + +import ( + "testing" + + "github.com/jhump/protoreflect/v2/protobuilder" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" + + protopts "go.linka.cloud/protodb/protodb" +) + +func TestRegistryKeyCacheLazyAndInvalidate(t *testing.T) { + r := &Registry{Files: &Files{}, Types: &Types{}} + + fd1 := buildResourceFile(t, "acme/v1/resource.proto", true) + require.NoError(t, r.RegisterFile(fd1)) + require.Empty(t, r.keys) + + d1, err := r.FindDescriptorByName("acme.v1.Resource") + require.NoError(t, err) + md1 := d1.(protoreflect.MessageDescriptor) + + field, ok := r.KeyFieldName(md1) + require.True(t, ok) + require.Equal(t, "metadata.id", field) + require.Len(t, r.keys, 1) + + m1 := dynamicpb.NewMessage(md1) + metaFD1 := md1.Fields().ByName("metadata") + meta1 := dynamicpb.NewMessage(metaFD1.Message()) + meta1.Set(metaFD1.Message().Fields().ByName("id"), protoreflect.ValueOfString("r1")) + m1.Set(metaFD1, protoreflect.ValueOfMessage(meta1)) + + key, field, err := r.KeyFor(m1) + require.NoError(t, err) + require.Equal(t, "r1", key) + require.Equal(t, "metadata.id", field) + + fd2 := buildResourceFile(t, "acme/v1/resource.proto", false) + require.NoError(t, r.RegisterFile(fd2)) + require.Empty(t, r.keys) + + d2, err := r.FindDescriptorByName("acme.v1.Resource") + require.NoError(t, err) + md2 := d2.(protoreflect.MessageDescriptor) + + field, ok = r.KeyFieldName(md2) + require.True(t, ok) + require.Equal(t, "name", field) + require.Len(t, r.keys, 1) + + m2 := dynamicpb.NewMessage(md2) + m2.Set(md2.Fields().ByName("name"), protoreflect.ValueOfString("r2")) + + key, field, err = r.KeyFor(m2) + require.NoError(t, err) + require.Equal(t, "r2", key) + require.Equal(t, "name", field) +} + +func TestRegistryKeyFallbackPrefersIDOverName(t *testing.T) { + r := &Registry{Files: &Files{}, Types: &Types{}} + + fd := buildFallbackOrderFile(t) + require.NoError(t, r.RegisterFile(fd)) + + d, err := r.FindDescriptorByName("acme.v1.NamedResource") + require.NoError(t, err) + md := d.(protoreflect.MessageDescriptor) + + field, ok := r.KeyFieldName(md) + require.True(t, ok) + require.Equal(t, "id", field) + + m := dynamicpb.NewMessage(md) + m.Set(md.Fields().ByName("name"), protoreflect.ValueOfString("n1")) + m.Set(md.Fields().ByName("id"), protoreflect.ValueOfString("i1")) + + key, field, err := r.KeyFor(m) + require.NoError(t, err) + require.Equal(t, "i1", key) + require.Equal(t, "id", field) +} + +func buildResourceFile(t *testing.T, path string, nested bool) protoreflect.FileDescriptor { + t.Helper() + keyOpts := &descriptorpb.FieldOptions{} + proto.SetExtension(keyOpts, protopts.E_Key, true) + + id := protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1) + if nested { + id = id.SetOptions(keyOpts) + } + meta := protobuilder.NewMessage("Metadata"). + AddField(id) + + name := protobuilder.NewField("name", protobuilder.FieldTypeString()).SetNumber(2) + if !nested { + name = name.SetOptions(keyOpts) + } + + resource := protobuilder.NewMessage("Resource"). + AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)). + AddField(name). + AddNestedMessage(meta) + + file := protobuilder.NewFile(path). + SetPackageName(protoreflect.FullName("acme.v1")). + SetSyntax(protoreflect.Proto3). + AddMessage(resource). + AddImportedDependency(protopts.File_protodb_protodb_proto) + + fd, err := file.Build() + require.NoError(t, err) + return fd +} + +func buildFallbackOrderFile(t *testing.T) protoreflect.FileDescriptor { + t.Helper() + msg := protobuilder.NewMessage("NamedResource"). + AddField(protobuilder.NewField("name", protobuilder.FieldTypeString()).SetNumber(1)). + AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(2)) + + file := protobuilder.NewFile("acme/v1/named_resource.proto"). + SetPackageName(protoreflect.FullName("acme.v1")). + SetSyntax(protoreflect.Proto3). + AddMessage(msg) + + fd, err := file.Build() + require.NoError(t, err) + return fd +}