Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 76 additions & 15 deletions cmd/protoc-gen-protodb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type mod struct {
oneOfs map[string]struct{}
}

type keyPath []pgs.Field

func (p *mod) Name() string {
return "protodb"
}
Expand Down Expand Up @@ -85,28 +87,87 @@ func (p *mod) Execute(targets map[string]pgs.File, _ map[string]pgs.Package) []p
if !ok || !enabled {
continue
}
var keyField string
for _, f := range m.Fields() {
var k bool
f.Extension(protodb.E_Key, &k)
if !k {
continue
}
if keyField != "" {
p.Failf("%s: %s: key already defined for message: %s", m.Name(), f.Name(), keyField)
}
keyField = f.Name().String()
if f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() || f.Descriptor().GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
p.Failf("%s: %s: only non repeated and not oneof scalar types are supported as key, got %T", m.Name(), f.Name(), f.Type())
}
}
p.validateMessageKey(m)
msgs = append(msgs, m)
}
p.generate(f, msgs)
}
return p.Artifacts()
}

func (p *mod) validateMessageKey(m pgs.Message) {
keys := p.collectMessageKeys(m)
if len(keys) == 0 {
return
}
for _, path := range keys {
leaf := path.leaf()
if leaf.Type().IsMap() || leaf.Type().IsRepeated() || leaf.InOneOf() || leaf.Descriptor().GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
p.Failf("%s: %s: only non repeated and not oneof scalar types are supported as key, got %T", m.Name(), path.String(), leaf.Type())
}
for i := 0; i < len(path)-1; i++ {
f := path[i]
if !f.Type().IsEmbed() || f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() {
p.Failf("%s: %s: nested key path must traverse singular message fields", m.Name(), path.String())
}
}
}
if len(keys) == 1 {
return
}
paths := make([]string, 0, len(keys))
for _, path := range keys {
paths = append(paths, path.String())
}
p.Failf("%s: multiple keys found for message: %s", m.Name(), strings.Join(paths, ", "))
}

func (p *mod) collectMessageKeys(m pgs.Message) []keyPath {
stack := map[string]bool{m.FullyQualifiedName(): true}
return p.collectMessageKeysIn(m, nil, stack)
}

func (p *mod) collectMessageKeysIn(m pgs.Message, path keyPath, stack map[string]bool) []keyPath {
var out []keyPath
for _, f := range m.Fields() {
next := make(keyPath, 0, len(path)+1)
next = append(next, path...)
next = append(next, f)
var key bool
f.Extension(protodb.E_Key, &key)
if key {
out = append(out, next)
}
if !f.Type().IsEmbed() || f.Type().IsMap() || f.Type().IsRepeated() || f.InOneOf() {
continue
}
embed := f.Type().Embed()
if embed == nil {
continue
}
name := embed.FullyQualifiedName()
if stack[name] {
continue
}
stack[name] = true
out = append(out, p.collectMessageKeysIn(embed, next, stack)...)
delete(stack, name)
}
return out
}

func (p keyPath) leaf() pgs.Field {
return p[len(p)-1]
}

func (p keyPath) String() string {
parts := make([]string, 0, len(p))
for _, f := range p {
parts = append(parts, f.Name().String())
}
return strings.Join(parts, ".")
}

func (p *mod) generate(f pgs.File, msgs []pgs.Message) {
if len(msgs) == 0 {
return
Expand Down
4 changes: 2 additions & 2 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func Open(ctx context.Context, opts ...Option) (protodb.DB, error) {
}
db := &db{opts: o, reg: reg, smu: mutex.NewKV(), ctxmu: mutex.NewContextKV()}
db.matcher = pf.NewMatcher()
db.idx = idxstore.NewIndexer(db.reg, db.reg.Files, db.unmarshal)
db.idx = idxstore.NewIndexer(db.reg, db.unmarshal)
if o.repl != nil {
h, err := server.NewServer(db)
if err != nil {
Expand Down Expand Up @@ -206,7 +206,7 @@ func (db *db) Watch(ctx context.Context, m proto.Message, opts ...protodb.GetOpt
o := makeGetOpts(opts...)
matcher := db.matcher

k, _, _, _ := protodb.DataPrefix(m)
k, _, _, _ := db.reg.DataPrefix(m)
log := logger.C(ctx).WithFields("service", "protodb", "action", "watch", "key", string(k))
log.Debugf("start watching for key %s", string(k))
ch := make(chan protodb.Event, 1)
Expand Down
28 changes: 20 additions & 8 deletions internal/db/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (tx *tx) get(ctx context.Context, m proto.Message, opts ...protodb.GetOptio
} else if lim := o.Paging.GetLimit(); lim > 0 {
out = make([]proto.Message, 0, lim)
}
prefix, field, value, _ := protodb.DataPrefix(m)
prefix, field, value, _ := tx.db.reg.DataPrefix(m)
span.SetAttributes(
attribute.String("prefix", string(prefix)),
attribute.String("key_field", field),
Expand Down Expand Up @@ -499,7 +499,7 @@ func (tx *tx) set(ctx context.Context, m proto.Message, opts ...protodb.SetOptio
if o.TTL != 0 {
expiresAt = uint64(time.Now().Add(o.TTL).Unix())
}
k, field, value, err := protodb.DataPrefix(m)
k, field, value, err := tx.db.reg.DataPrefix(m)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -645,7 +645,7 @@ func (tx *tx) delete(ctx context.Context, m proto.Message) error {
return badger.ErrReadOnlyTxn
}
// TODO(adphi): should we check / read for key first ?
k, field, value, err := protodb.DataPrefix(m)
k, field, value, err := tx.db.reg.DataPrefix(m)
if err != nil {
return err
}
Expand Down Expand Up @@ -878,7 +878,7 @@ func (tx *tx) getOrdered(ctx context.Context, m proto.Message, o protodb.GetOpts
)
defer span.End()
}
plan, err := buildOrderPlan(m.ProtoReflect().New(), o.OrderBy)
plan, err := buildOrderPlanWithKeyField(m.ProtoReflect().New(), o.OrderBy, tx.db.reg.KeyFieldName)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -929,7 +929,7 @@ func (tx *tx) getOrderedIndexed(ctx context.Context, m proto.Message, o protodb.
span.SetAttributes(attribute.Bool("fallback", true), attribute.String("fallback_reason", "key_order"))
return nil, nil, false, nil
}
prefix, _, _, _ := protodb.DataPrefix(m)
prefix, _, _, _ := tx.db.reg.DataPrefix(m)
if o.Filter != nil {
ok, err := tx.db.idx.IndexableFilter(m, o.Filter)
if err != nil {
Expand Down Expand Up @@ -1082,7 +1082,7 @@ func (tx *tx) getOrderedFallbackScanSort(ctx context.Context, m proto.Message, o

ordered := make([]orderedResult, 0, len(all))
for _, item := range all {
key, _, _, err := protodb.DataPrefix(item)
key, _, _, err := tx.db.reg.DataPrefix(item)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -1158,11 +1158,15 @@ func (tx *tx) getOrderedFallbackScanSort(ctx context.Context, m proto.Message, o
}

func buildOrderPlan(msg protoreflect.Message, orderBy *v1alpha1.OrderBy) (orderField, error) {
return buildOrderPlanWithKeyField(msg, orderBy, protodb.KeyFieldName)
}

func buildOrderPlanWithKeyField(msg protoreflect.Message, orderBy *v1alpha1.OrderBy, keyFieldName func(protoreflect.MessageDescriptor) (string, bool)) (orderField, error) {
if orderBy == nil {
return orderField{}, errors.New("order_by cannot be empty")
}
md := msg.Descriptor()
keyField, hasKey := protodb.KeyFieldName(md)
keyField, hasKey := keyFieldName(md)
fieldPath := strings.TrimSpace(orderBy.GetField())
if fieldPath == "" {
return orderField{}, errors.New("order_by field cannot be empty")
Expand All @@ -1174,7 +1178,7 @@ func buildOrderPlan(msg protoreflect.Message, orderBy *v1alpha1.OrderBy) (orderF
if !isOrderableFieldPath(fds) {
return orderField{}, fmt.Errorf("order_by field %q is not sortable", fieldPath)
}
isKey := hasKey && len(fds) == 1 && string(fds[0].Name()) == keyField
isKey := hasKey && nameFieldPath(fds) == keyField
if !isKey {
fd := fds[len(fds)-1]
if !proto.HasExtension(fd.Options(), protopts.E_Index) {
Expand Down Expand Up @@ -1319,3 +1323,11 @@ func hash(f protodb.Filter) (hash string, err error) {
h := sha512.Sum512(b)
return base64.StdEncoding.EncodeToString(h[:]), nil
}

func nameFieldPath(fds []protoreflect.FieldDescriptor) string {
parts := make([]string, 0, len(fds))
for _, fd := range fds {
parts = append(parts, string(fd.Name()))
}
return strings.Join(parts, ".")
}
41 changes: 41 additions & 0 deletions internal/db/tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@ import (
"testing"

"github.com/dgraph-io/badger/v3"
"github.com/jhump/protoreflect/v2/protobuilder"
"github.com/stretchr/testify/require"
"go.linka.cloud/protofilters/filters"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"

"go.linka.cloud/protodb/internal/protodb"
"go.linka.cloud/protodb/internal/token"
protopts "go.linka.cloud/protodb/protodb"
v1alpha1 "go.linka.cloud/protodb/protodb/v1alpha1"
)

Expand Down Expand Up @@ -55,6 +60,18 @@ func TestBuildOrderPlan(t *testing.T) {
_, err := buildOrderPlan(m.ProtoReflect(), &v1alpha1.OrderBy{Field: "status", Direction: v1alpha1.OrderDirection(99)})
require.ErrorContains(t, err, "invalid direction")
})

require.NoError(t, d.RegisterProto(ctx, buildNestedKeyOrderFileDescriptor(t)))
nmd, err := lookupMessage(d, "tests.order.OrderDoc")
require.NoError(t, err)
nm := dynamicpb.NewMessage(nmd)

t.Run("nested key field", func(t *testing.T) {
pl, err := buildOrderPlan(nm.ProtoReflect(), &v1alpha1.OrderBy{Field: "metadata.id"})
require.NoError(t, err)
require.Equal(t, "metadata.id", pl.fieldPath)
require.Equal(t, v1alpha1.OrderDirectionAsc, pl.direction)
})
}

func TestGetContinuationValidation(t *testing.T) {
Expand Down Expand Up @@ -297,6 +314,30 @@ func lookupNumberField(md protoreflect.MessageDescriptor, n string) (protoreflec
return nil, fmt.Errorf("field number %s not found", n)
}

func buildNestedKeyOrderFileDescriptor(t *testing.T) *descriptorpb.FileDescriptorProto {
t.Helper()
keyOpts := &descriptorpb.FieldOptions{}
proto.SetExtension(keyOpts, protopts.E_Key, true)

meta := protobuilder.NewMessage("Metadata").
AddField(protobuilder.NewField("id", protobuilder.FieldTypeString()).SetNumber(1).SetOptions(keyOpts))

msg := protobuilder.NewMessage("OrderDoc").
AddField(protobuilder.NewField("metadata", protobuilder.FieldTypeMessage(meta)).SetNumber(1)).
AddField(protobuilder.NewField("status", protobuilder.FieldTypeString()).SetNumber(2)).
AddNestedMessage(meta)

file := protobuilder.NewFile("tests/order_nested_key.proto").
SetPackageName(protoreflect.FullName("tests.order")).
SetSyntax(protoreflect.Proto3).
AddMessage(msg).
AddImportedDependency(protopts.File_protodb_protodb_proto)

fd, err := file.Build()
require.NoError(t, err)
return protodesc.ToFileDescriptorProto(fd)
}

func TestGetWithMalformedTokenReturnsInvalid(t *testing.T) {
ctx := context.Background()
dbif, err := Open(ctx, WithPath(t.TempDir()), WithApplyDefaults(true))
Expand Down
Loading
Loading