diff --git a/CHANGELOG.md b/CHANGELOG.md index a9ada00e540..d6d3e1c0037 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ adhere to [Semantic Versioning](https://semver.org) starting `v22.0.0`. - **Fixed** - fix(core): fix panic in verifyUniqueWithinMutation when mutation is conditionally pruned (#9450) - fix(query): return full float value in query results (#9492) +- **Vector** + - fix(vector/hnsw): correct early termination in bottom-layer search to ensure at least k + candidates are considered before breaking + - feat(vector/hnsw): add optional per-query controls to similar_to via a 4th argument: `ef` + (search breadth override) and `distance_threshold` (metric-domain cutoff); defaults unchanged ## [v24.X.X] - YYYY-MM-DD @@ -100,7 +105,7 @@ adhere to [Semantic Versioning](https://semver.org) starting `v22.0.0`. - **Perf** - perf(query): Read just the latest value for scalar types https://github.com/hypermodeinc/dgraph/pull/8966 - - perf(vector): Add heap to neighbour edges https://github.com/hypermodeinc/dgraph/pull/9122 + - perf(vector): Add heap to neighbor edges https://github.com/hypermodeinc/dgraph/pull/9122 ## [v24.0.1] - 2024-07-30 @@ -4798,8 +4803,8 @@ Users can set `port_offset` flag, to modify these fixed ports. - `Query` Grpc endpoint returns response in JSON under `Json` field instead of protocol buffer. `client.Unmarshal` method also goes away from the Go client. Users can use `json.Unmarshal` for unmarshalling the response. -- Response for predicate of type `geo` can be unmarshalled into a struct. Example - [here](https://godoc.org/github.com/hypermodeinc/dggraph/client#example-package--SetObject). +- Response for predicate of type `geo` can be unmarshalled into a struct. See the + [SetObject example](https://godoc.org/github.com/hypermodeinc/dggraph/client#example-package--SetObject). - `Node` and `Edge` structs go away along with the `SetValue...` methods. 
We recommend using [`SetJson`](https://godoc.org/github.com/hypermodeinc/dggraph/client#example-package--SetObject) and `DeleteJson` fields to do mutations. diff --git a/dql/parser.go b/dql/parser.go index b2c59743aef..b4101ddd1b9 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -1874,7 +1874,10 @@ L: case IsInequalityFn(function.Name): err = parseFuncArgs(it, function) - case function.Name == "uid_in" || function.Name == "similar_to": + case function.Name == "uid_in": + err = parseFuncArgs(it, function) + + case function.Name == "similar_to": err = parseFuncArgs(it, function) default: @@ -1892,6 +1895,25 @@ L: } expectArg = false continue + case itemLeftCurl: + // Guard: Only similar_to may use object-literal syntax in its 4th argument. + // By checking Name=="similar_to", Attr is set (predicate) and Args has + // exactly two elements (k and vector), we ensure the '{' is in position 4. + // All other functions receive the historical error for stray braces. + if function.Name != "similar_to" || function.Attr == "" || len(function.Args) != 2 { + return nil, itemInFunc.Errorf("Unrecognized character inside a func: U+007B '{'") + } + // Parse the object literal: {ef: 64, distance_threshold: 0.45} + // The helper consumes tokens until the matching '}' is found. + if err := parseSimilarToObjectArg(it, function, itemInFunc); err != nil { + return nil, err + } + expectArg = false + continue + case itemRightCurl: + // Right curly braces are never valid in function arguments outside of + // the object literal parsed above. Always error on stray '}'. + return nil, itemInFunc.Errorf("Unrecognized character inside a func: U+007D '}'") default: if itemInFunc.Typ != itemName { return nil, itemInFunc.Errorf("Expected arg after func [%s], but got item %v", @@ -2408,6 +2430,10 @@ loop: // The parentheses are balanced out. Let's break. 
break loop } + case item.Typ == itemLeftCurl: + return nil, item.Errorf("Unrecognized character inside a func: U+007B '{'") + case item.Typ == itemRightCurl: + return nil, item.Errorf("Unrecognized character inside a func: U+007D '}'") default: return nil, item.Errorf("Unexpected item while parsing @filter: %v", item) } @@ -3471,3 +3497,32 @@ func trySkipItemTyp(it *lex.ItemIterator, typ lex.ItemType) bool { it.Next() return true } + +func parseSimilarToObjectArg(it *lex.ItemIterator, function *Function, start lex.Item) error { + depth := 1 + var builder strings.Builder + builder.WriteString(start.Val) + + for depth > 0 { + if !it.Next() { + return start.Errorf("Unexpected end of object literal while parsing similar_to options") + } + + item := it.Item() + builder.WriteString(item.Val) + + switch item.Typ { + case itemLeftCurl: + depth++ + case itemRightCurl: + depth-- + case itemRightRound: + if depth > 0 { + return item.Errorf("Expected '}' before ')' in similar_to options") + } + } + } + + function.Args = append(function.Args, Arg{Value: builder.String()}) + return nil +} diff --git a/dql/parser_test.go b/dql/parser_test.go index d764fe9060e..e7c7477d93d 100644 --- a/dql/parser_test.go +++ b/dql/parser_test.go @@ -2518,6 +2518,12 @@ func TestParseFilter_brac(t *testing.T) { } // Test if unbalanced brac will lead to errors. +// Note: This query has two errors: missing ')' after '()' AND a stray '{'. +// After changes to support similar_to's JSON args the lexer now emits brace tokens +// instead of erroring immediately. This causes the query to fail on the structural +// error (unclosed brackets) rather than the character-specific error. This is an +// acceptable trade-off because queries with multiple syntax errors may report a different +// (but equally fatal) error first. 
func TestParseFilter_unbalancedbrac(t *testing.T) { query := ` query { @@ -2532,8 +2538,76 @@ func TestParseFilter_unbalancedbrac(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), - "Unrecognized character inside a func: U+007B '{'") + require.Contains(t, err.Error(), "Unclosed Brackets") +} + +func TestParseSimilarToObjectLiteral(t *testing.T) { + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", {distance_threshold: 1.5, ef: 12})) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + require.Len(t, res.Query, 1) + require.NotNil(t, res.Query[0]) + require.NotNil(t, res.Query[0].Func) + require.Equal(t, "similar_to", res.Query[0].Func.Name) + require.Len(t, res.Query[0].Func.Args, 3) + require.Equal(t, "voptions", res.Query[0].Func.Attr) + require.Equal(t, "4", res.Query[0].Func.Args[0].Value) + require.Equal(t, "[0,0]", res.Query[0].Func.Args[1].Value) + require.Equal(t, "{distance_threshold:1.5,ef:12}", res.Query[0].Func.Args[2].Value) +} + +func TestParseSimilarToStringOptions(t *testing.T) { + // Test string-based options format (backwards compatibility) + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", "ef=64,distance_threshold=0.45")) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + require.Equal(t, "similar_to", res.Query[0].Func.Name) + require.Equal(t, "ef=64,distance_threshold=0.45", res.Query[0].Func.Args[2].Value) +} + +func TestParseSimilarToThreeArgs(t *testing.T) { + // Test three-arg form (no options) + query := `{ + q(func: similar_to(voptions, 4, "[0,0]")) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + require.Equal(t, "similar_to", res.Query[0].Func.Name) + require.Len(t, res.Query[0].Func.Args, 2) +} + +func TestParseSimilarToBraceInWrongPosition(t *testing.T) { + // Brace as third argument should be rejected + query := `{ + q(func: similar_to(voptions, 4, {ef: 
12})) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Unrecognized character inside a func: U+007B '{'") +} + +func TestParseNonSimilarToWithBrace(t *testing.T) { + // Braces in non-similar_to functions should be rejected + query := `{ + q(func: eq(name, {value: "test"})) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Unrecognized character inside a func: U+007B '{'") } func TestParseFilter_Geo1(t *testing.T) { @@ -2768,6 +2842,10 @@ func TestParseCountAsFunc(t *testing.T) { } +// Note: This query has two errors: missing ')' after 'friends' AND a stray '}'. +// After changes to support similar_to's JSON args the lexer emits brace tokens instead +// of erroring immediately -- causing this to fail on unclosed brackets rather than the +// specific character error. See TestParseFilter_unbalancedbrac for full explanation. func TestParseCountError1(t *testing.T) { query := `{ me(func: uid(1)) { @@ -2779,10 +2857,11 @@ func TestParseCountError1(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), - "Unrecognized character inside a func: U+007D '}'") + require.Contains(t, err.Error(), "Unclosed Brackets") } +// Note: Similar to TestParseCountError1, this has missing ')' and stray '}', +// now reports structural error instead of character-specific error. 
func TestParseCountError2(t *testing.T) { query := `{ me(func: uid(1)) { @@ -2794,8 +2873,7 @@ func TestParseCountError2(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), - "Unrecognized character inside a func: U+007D '}'") + require.Contains(t, err.Error(), "Unclosed Brackets") } func TestParseCheckPwd(t *testing.T) { diff --git a/dql/state.go b/dql/state.go index 2b17fa8315a..3b4fef3788f 100644 --- a/dql/state.go +++ b/dql/state.go @@ -306,6 +306,19 @@ func lexFuncOrArg(l *lex.Lexer) lex.StateFn { l.Emit(itemLeftSquare) case r == rightSquare: l.Emit(itemRightSquare) + case r == leftCurl: + empty = false + l.Emit(itemLeftCurl) + // Design decision: Emit brace tokens without affecting ArgDepth tracking. + // This allows similar_to's JSON-style options ({ef: 64, distance_threshold: 0.45}) + // to be parsed. The parser validates whether braces are legal in context. + // Trade-off: Queries with multiple syntax errors (e.g., missing ')' AND stray '}') + // will report structural errors (Unclosed Brackets) rather than character-specific + // errors. This is acceptable as the query is still rejected with a clear error. + case r == rightCurl: + l.Emit(itemRightCurl) + // Don't decrement ArgDepth for braces; let parser validate context. + // See leftCurl case above for full rationale. case r == '#': return lexComment case r == '.': diff --git a/query/vector/vector_test.go b/query/vector/vector_test.go index 80616fe6e14..35a2b05cc3e 100644 --- a/query/vector/vector_test.go +++ b/query/vector/vector_test.go @@ -417,6 +417,74 @@ func TestVectorIndexRebuildWhenChange(t *testing.T) { require.Greater(t, dur, time.Second*4) } +func TestSimilarToOptionsIntegration(t *testing.T) { + const pred = "voptions" + dropPredicate(pred) + t.Cleanup(func() { dropPredicate(pred) }) + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidean")) + + rdf := `<0x1> "[0,0]" . + <0x2> "[1,0]" . + <0x3> "[2,0]" . 
+ <0x4> "[5,0]" .` + require.NoError(t, addTriplesToCluster(rdf)) + + t.Run("ef_override_string_syntax", func(t *testing.T) { + query := `{ + results(func: similar_to(voptions, 3, "[0,0]", "ef=2")) { + uid + } + }` + resp := processQueryNoErr(t, query) + + var result struct { + Data struct { + Results []struct { + UID string `json:"uid"` + } `json:"results"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(resp), &result)) + require.Len(t, result.Data.Results, 3) + + expected := map[string]struct{}{"0x1": {}, "0x2": {}, "0x3": {}} + for _, r := range result.Data.Results { + _, ok := expected[r.UID] + require.Truef(t, ok, "unexpected uid %s", r.UID) + delete(expected, r.UID) + } + require.Empty(t, expected) + }) + + t.Run("distance_threshold_json_syntax", func(t *testing.T) { + query := `{ + results(func: similar_to(voptions, 4, "[0,0]", {distance_threshold: 1.5})) { + uid + } + }` + resp := processQueryNoErr(t, query) + + var result struct { + Data struct { + Results []struct { + UID string `json:"uid"` + } `json:"results"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(resp), &result)) + require.Len(t, result.Data.Results, 2) + + expected := map[string]struct{}{"0x1": {}, "0x2": {}} + for _, r := range result.Data.Results { + _, ok := expected[r.UID] + require.Truef(t, ok, "unexpected uid %s", r.UID) + delete(expected, r.UID) + } + require.Empty(t, expected) + }) +} + func TestVectorInQueryArgument(t *testing.T) { dropPredicate("vtest") setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidean")) diff --git a/tok/hnsw/ef_recall_test.go b/tok/hnsw/ef_recall_test.go new file mode 100644 index 00000000000..50d1269b8c2 --- /dev/null +++ b/tok/hnsw/ef_recall_test.go @@ -0,0 +1,208 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +package hnsw + +import ( + "context" + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/hypermodeinc/dgraph/v25/tok/index" + opt "github.com/hypermodeinc/dgraph/v25/tok/options" + "github.com/hypermodeinc/dgraph/v25/x" +) + +// memoryCache satisfies index.CacheType for synthetic tests. +type memoryCache struct { + data map[string][]byte +} + +func (m *memoryCache) Get(key []byte) ([]byte, error) { + if val, ok := m.data[string(key)]; ok { + return val, nil + } + return nil, nil +} + +func (m *memoryCache) Ts() uint64 { return 0 } + +func (m *memoryCache) Find([]byte, func([]byte) bool) (uint64, error) { return 0, nil } + +func float64ArrayAsBytes(v []float64) []byte { + buf := make([]byte, 8*len(v)) + for i, f := range v { + binary.LittleEndian.PutUint64(buf[i*8:], math.Float64bits(f)) + } + return buf +} + +// Test that EfOverride widens the bottom-layer candidate set and improves recall on a tiny graph. +func TestHNSWSearchEfOverrideImprovesRecall(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(EfSearchOpt, 1) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + + predName := "joefix_pred" + predWithNamespace := x.NamespaceAttr(x.RootNamespace, predName) + + rawIdx, err := factory.Create(predWithNamespace, options, 64) + require.NoError(t, err) + + // Use concrete type directly (same package) to set up a tiny synthetic graph. + ph, ok := rawIdx.(*persistentHNSW[float64]) + require.True(t, ok) + require.Equal(t, predWithNamespace, ph.pred) + + // Populate vectors in memory via cache data map keyed by DataKey. 
+ vectors := map[uint64][]float64{ + 1: {0, 0, 10, 0}, // entry + 100: {0, 0, 0.1, 0}, // true nearest to query + 200: {0, 0, 3, 0}, // local minimum path + 201: {0, 0, 3.2, 0}, + } + + data := make(map[string][]byte) + for uid, vec := range vectors { + key := string(DataKey(ph.pred, uid)) + data[key] = float64ArrayAsBytes(vec) + } + + // Set entry pointer to uid 1. + entryKey := string(DataKey(ph.vecEntryKey, 1)) + data[entryKey] = Uint64ToBytes(1) + + // Wire a small graph that requires wider search to find uid 100 from entry 1. + ph.nodeAllEdges[1] = [][]uint64{{}, {200, 201}} + ph.nodeAllEdges[200] = [][]uint64{{1}, {1}} + ph.nodeAllEdges[201] = [][]uint64{{1}, {100}} + ph.nodeAllEdges[100] = [][]uint64{{201}, {201}} + + cache := &memoryCache{data: data} + + // Narrow ef behaves like legacy path: returns uid 200 for k=1. + narrow, err := ph.SearchWithOptions(ctx, cache, []float64{0, 0, 0.12, 0}, 1, index.VectorIndexOptions[float64]{}) + require.NoError(t, err) + require.Equal(t, []uint64{200}, narrow) + + // Wider ef surfaces the closer neighbor uid 100. + wide, err := ph.SearchWithOptions(ctx, cache, []float64{0, 0, 0.12, 0}, 1, index.VectorIndexOptions[float64]{EfOverride: 4}) + require.NoError(t, err) + require.Equal(t, []uint64{100}, wide) +} + +// Test Euclidean distance_threshold drops results whose distance (not squared) exceeds the threshold. +func TestHNSWDistanceThreshold_Euclidean(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 1) + options.SetOpt(EfSearchOpt, 10) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + + pred := x.NamespaceAttr(x.RootNamespace, "thresh_pred_e") + rawIdx, err := factory.Create(pred, options, 64) + require.NoError(t, err) + ph := rawIdx.(*persistentHNSW[float64]) + + // Two vectors at known Euclidean distances from query. 
+ // query q = (0,0), a=(0.6,0), b=(0.8,0) + // dist(q,a)=0.6, dist(q,b)=0.8 + data := map[string][]byte{ + string(DataKey(pred, 1)): float64ArrayAsBytes([]float64{0.6, 0}), + string(DataKey(pred, 2)): float64ArrayAsBytes([]float64{0.8, 0}), + string(DataKey(ph.vecEntryKey, 1)): Uint64ToBytes(1), + } + // Single-layer edges; ensure both are reachable from entry. + ph.nodeAllEdges[1] = [][]uint64{{1, 2}} + ph.nodeAllEdges[2] = [][]uint64{{1}} + + cache := &memoryCache{data: data} + q := []float64{0, 0} + + // With current internal Euclidean values, use thresholds directly in the metric domain. + // threshold 0.75: include uid 1 (0.6) and exclude uid 2 (0.8). + th := 0.75 + res, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &th, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, res) + + // threshold 1.5 (>1) should include both neighbors and demonstrate we keep working in raw distance. + thHigh := 1.5 + resHigh, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &thHigh, + EfOverride: 10, + }) + require.NoError(t, err) + require.ElementsMatch(t, []uint64{1, 2}, resHigh) + + // threshold 0.6 includes uid 1 (distance 0.6) but excludes the rest (inclusive comparison). + thExact := 0.6 + resExact, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &thExact, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, resExact) + + // threshold 0.5 should filter everything. + thLow := 0.5 + resLow, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &thLow, + EfOverride: 10, + }) + require.NoError(t, err) + require.Empty(t, resLow) +} + +// Test Cosine distance_threshold uses distance d = 1 - cosine_similarity. 
+func TestHNSWDistanceThreshold_Cosine(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 1) + options.SetOpt(EfSearchOpt, 10) + options.SetOpt(MetricOpt, GetSimType[float64](Cosine, 64)) + + pred := x.NamespaceAttr(x.RootNamespace, "thresh_pred_c") + rawIdx, err := factory.Create(pred, options, 64) + require.NoError(t, err) + ph := rawIdx.(*persistentHNSW[float64]) + + // Query q is unit along x-axis. + // a is exact match (cos sim 1.0, distance 0.0) + // b is 36.87 degrees (~cos 0.8, distance 0.2) + data := map[string][]byte{ + string(DataKey(pred, 1)): float64ArrayAsBytes([]float64{1, 0}), + string(DataKey(pred, 2)): float64ArrayAsBytes([]float64{0.8, 0.6}), + string(DataKey(ph.vecEntryKey, 1)): Uint64ToBytes(1), + } + ph.nodeAllEdges[1] = [][]uint64{{1, 2}} + ph.nodeAllEdges[2] = [][]uint64{{1}} + + cache := &memoryCache{data: data} + q := []float64{1, 0} + + // distance_threshold=0.1 should include uid 1 but exclude uid 2 (0.2 > 0.1) + th := 0.1 + res, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &th, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, res) +} diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index e13ddddaf89..09052677e04 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -29,7 +29,7 @@ type persistentHNSW[T c.Float] struct { simType SimilarityType[T] floatBits int // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first - // layer for uuid 65443. The result will be a neighboring uuid. + // layer for UUID 65443. The result will be a neighboring UUID. 
nodeAllEdges map[uint64][][]uint64 deadNodes map[uint64]struct{} } @@ -145,7 +145,7 @@ func (ph *persistentHNSW[T]) fillNeighborEdges(uuid uint64, c index.CacheType, e return true, nil } -// searchPersistentLayer searches a layer of the hnsw graph for the nearest +// searchPersistentLayer searches a layer of the HNSW graph for the nearest // neighbors of the query vector and returns the traversal path and the nearest // neighbors func (ph *persistentHNSW[T]) searchPersistentLayer( @@ -177,16 +177,13 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( //create set using map to append to on future visited nodes for candidateHeap.Len() != 0 { currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) - if r.numNeighbors() < expectedNeighbors && + if r.numNeighbors() >= expectedNeighbors && ph.simType.isBetterScore(r.lastNeighborScore(), currCandidate.value) { - // If the "worst score" in our neighbors list is deemed to have - // a better score than the current candidate -- and if we have at - // least our expected number of nearest results -- we discontinue - // the search. - // Note that while this is faithful to the published - // HNSW algorithms insofar as we stop when we reach a local - // minimum, it leaves something to be desired in terms of - // guarantees of getting best results. + // Standard HNSW termination: once the last neighbor kept so far + // scores better than the candidate just popped (and we already + // hold at least expectedNeighbors), we stop exploring this layer. + // Recall is governed by ef; callers may raise ef (per-query + // override supported) to explore further. 
break } @@ -246,7 +243,7 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( return r, nil } -// Search searches the hnsw graph for the nearest neighbors of the query vector +// Search searches the HNSW graph for the nearest neighbors of the query vector // and returns the traversal path and the nearest neighbors func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { @@ -254,7 +251,173 @@ func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, quer return r.Neighbors, err } -// SearchWithUid searches the hnsw graph for the nearest neighbors of the query uid +// SearchWithOptions applies optional per-call controls (ef override and distance threshold). +// The effective ef is EfOverride when it is > 0, otherwise the index's efSearch. Upper layers +// search with that ef; the bottom layer uses candidateK = max(maxResults, ef). +// When DistanceThreshold is set (honored for Euclidean and Cosine only), neighbors beyond the +// threshold in the metric domain are dropped before the result is capped at maxResults. 
+func (ph *persistentHNSW[T]) SearchWithOptions( + ctx context.Context, + c index.CacheType, + query []T, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, error) { + if opts.Filter == nil { + opts.Filter = index.AcceptAll[T] + } + if maxResults < 0 { + maxResults = 0 + } + r := index.NewSearchPathResult() + start := time.Now().UnixMilli() + + // 0-profile_vector_entry + var startVec []T + entry, err := ph.PickStartNode(ctx, c, &startVec) + if err != nil { + return nil, err + } + + // Upper layers use efUpper (override if provided) + efUpper := ph.efSearch + if opts.EfOverride > 0 { + efUpper = opts.EfOverride + } + + for level := range ph.maxLevels - 1 { + if isEqual(startVec, query) { + break + } + filterOut := !opts.Filter(query, startVec, entry) + layerResult, err := ph.searchPersistentLayer( + c, level, entry, startVec, query, filterOut, efUpper, opts.Filter) + if err != nil { + return nil, err + } + layerResult.updateFinalMetrics(r) + entry = layerResult.bestNeighbor().index + layerResult.updateFinalPath(r) + if err = ph.getVecFromUid(entry, c, &startVec); err != nil { + return nil, err + } + } + + // Bottom layer: candidate size = max(k, efUpper) + filterOut := !opts.Filter(query, startVec, entry) + candidateK := maxResults + if efUpper > candidateK { + candidateK = efUpper + } + layerResult, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, entry, startVec, query, filterOut, candidateK, opts.Filter) + if err != nil { + return nil, err + } + layerResult.updateFinalMetrics(r) + layerResult.updateFinalPath(r) + + // Build final neighbor list with optional threshold, limited to maxResults. + res := make([]uint64, 0, maxResults) + for _, n := range layerResult.neighbors { + if maxResults == 0 { + break + } + if n.filteredOut { + continue + } + if opts.DistanceThreshold != nil { + th := *opts.DistanceThreshold + switch ph.simType.indexType { + case Euclidean: + // n.value stores the metric-domain distance (not squared). 
+ if float64(n.value) > th { + continue + } + case Cosine: + // n.value is cosine similarity in [-1,1]; cosine distance d = 1 - sim must be <= th. + if float64(1.0)-float64(n.value) > th { + continue + } + default: + // Dot product or others: ignore threshold for now. + } + } + res = append(res, n.index) + if len(res) >= maxResults { + break + } + } + + r.Metrics[searchTime] = uint64(time.Now().UnixMilli() - start) + return res, nil +} + +// SearchWithUidAndOptions is SearchWithUid (bottom-layer search only) with per-call options. +func (ph *persistentHNSW[T]) SearchWithUidAndOptions( + _ context.Context, + c index.CacheType, + queryUid uint64, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, error) { + if opts.Filter == nil { + opts.Filter = index.AcceptAll[T] + } + if maxResults < 0 { + maxResults = 0 + } + var queryVec []T + if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { + if errors.Is(err, errFetchingPostingList) { + return []uint64{}, nil + } + return []uint64{}, err + } + if len(queryVec) == 0 { + return []uint64{}, nil + } + filterOut := !opts.Filter(queryVec, queryVec, queryUid) + candidateK := maxResults + if opts.EfOverride > candidateK { + candidateK = opts.EfOverride + } + lr, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, queryUid, queryVec, queryVec, filterOut, candidateK, opts.Filter) + if err != nil { + return []uint64{}, err + } + res := make([]uint64, 0, maxResults) + for _, n := range lr.neighbors { + if maxResults == 0 { + break + } + if n.filteredOut { + continue + } + if opts.DistanceThreshold != nil { + th := *opts.DistanceThreshold + switch ph.simType.indexType { + case Euclidean: + if float64(n.value) > th { + continue + } + case Cosine: + if float64(1.0)-float64(n.value) > th { + continue + } + default: + } + } + res = append(res, n.index) + if len(res) >= maxResults { + break + } + } + return res, nil +} + +// SearchWithUid searches the HNSW graph for the nearest neighbors of the query UID // and
returns the traversal path and the nearest neighbors func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, queryUid uint64, maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { @@ -275,9 +438,9 @@ func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, shouldFilterOutQueryVec := !filter(queryVec, queryVec, queryUid) - // how normal search works is by cotinuously searching higher layers - // for the best entry node to the last layer since we already know the - // best entry node (since it already exists in the lowest level), we + // Normal search works by continuously searching higher layers + // for the best entry node to the last layer. Since we already know the + // best entry node (it already exists in the lowest level), we // can just search the last layer and return the results. r, err := ph.searchPersistentLayer( c, ph.maxLevels-1, queryUid, queryVec, queryVec, @@ -389,7 +552,7 @@ func (ph *persistentHNSW[T]) SearchWithPath( return r, nil } -// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// Insert inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, inUuid uint64, inVec []T) ([]*index.KeyValue, error) { @@ -401,7 +564,7 @@ func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, return edges, err } -// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// insertHelper inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, inUuid uint64, inVec []T) ([]minPersistentHeapElement[T], []*index.KeyValue, error) { diff --git a/tok/index/index.go b/tok/index/index.go index e0a62255ce1..edc203fb89c 100644 --- a/tok/index/index.go +++ b/tok/index/index.go 
@@ -118,6 +118,35 @@ type VectorIndex[T c.Float] interface { Insert(ctx context.Context, c CacheType, uuid uint64, vec []T) ([]*KeyValue, error) } +// VectorIndexOptions carries optional, per-call search tuning parameters. +// Zero values mean "no override". +type VectorIndexOptions[T c.Float] struct { + // EfOverride, when > 0, overrides the search breadth (ef) for this call. + // Implementations should apply this to upper layers and use max(k, ef) for + // the bottom layer candidate size, then return the best k. + EfOverride int + + // DistanceThreshold, when non-nil, filters out neighbors whose metric-domain + // distance exceeds the given threshold. Semantics depend on the index metric: + // - Euclidean: direct Euclidean distance (not squared) + // - Cosine: cosine distance in [0,2] (1 - cosine_similarity) + // - Dot product: undefined; implementations may ignore + DistanceThreshold *float64 + + // Filter allows callers to pass a SearchFilter; if nil, AcceptAll should be used. + Filter SearchFilter[T] +} + +// OptionalSearchOptions adds per-call search controls without breaking existing APIs. +// Implementations that support these may choose to ignore unsupported fields. 
+type OptionalSearchOptions[T c.Float] interface { + SearchWithOptions(ctx context.Context, c CacheType, query []T, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) + + SearchWithUidAndOptions(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) +} + // A Txn is an interface representation of a persistent storage transaction, // where multiple operations are performed on a database type Txn interface { diff --git a/tok/index/search_path.go b/tok/index/search_path.go index 1c247e926f5..efdcc7e7ad2 100644 --- a/tok/index/search_path.go +++ b/tok/index/search_path.go @@ -5,11 +5,11 @@ package index -// SearchPathResult is the return-type for the optional +// SearchPathResult is the return type for the optional // SearchWithPath function for a VectorIndex // (by way of extending OptionalIndexSupport). type SearchPathResult struct { - // The collection of nearest-neighbors in sorted order after filtlering + // The collection of nearest neighbors in sorted order after filtering // out neighbors that fail any Filter criteria. Neighbors []uint64 // The path from the start of search to the closest neighbor vector. @@ -19,8 +19,8 @@ type SearchPathResult struct { Metrics map[string]uint64 } -// NewSearchPathResult() provides an initialized (empty) *SearchPathResult. -// The attributes will be non-nil, but empty. +// NewSearchPathResult provides an initialized (empty) *SearchPathResult. +// The attributes will be non-nil but empty. 
func NewSearchPathResult() *SearchPathResult { return &SearchPathResult{ Neighbors: []uint64{}, diff --git a/worker/task.go b/worker/task.go index ba7e859572f..2a23bb1d05f 100644 --- a/worker/task.go +++ b/worker/task.go @@ -368,12 +368,29 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er return err } var nnUids []uint64 - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) + // Build optional search options if provided + filter := index.AcceptAll[float32] + opts := index.VectorIndexOptions[float32]{Filter: filter} + if srcFn.vsEfOverride > 0 { + opts.EfOverride = srcFn.vsEfOverride + } + if srcFn.vsDistanceThreshold != nil { + opts.DistanceThreshold = srcFn.vsDistanceThreshold + } + if o, ok := indexer.(index.OptionalSearchOptions[float32]); ok && (opts.EfOverride > 0 || opts.DistanceThreshold != nil) { + if srcFn.vectorInfo != nil { + nnUids, err = o.SearchWithOptions(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) + } else { + nnUids, err = o.SearchWithUidAndOptions(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + } } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) + if srcFn.vectorInfo != nil { + nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) + } else { + nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) + } } if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { @@ -1792,6 +1809,9 @@ type functionContext struct { atype types.TypeID vectorInfo []float32 vectorUid uint64 + // Optional vector search options parsed from a 3rd arg on similar_to + vsEfOverride int + vsDistanceThreshold *float64 } const ( @@ -2119,13 +2139,49 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { } 
checkRoot(q, fc) case similarToFn: - if err = ensureArgsCount(q.SrcFunc, 2); err != nil { - return nil, err + // Allow 2 or 3 args: k, vector_or_uid[, options] + if !(len(q.SrcFunc.Args) == 2 || len(q.SrcFunc.Args) == 3) { + return nil, errors.Errorf("Function '%s' requires 2 or 3 arguments, but got %d (%v)", q.SrcFunc.Name, len(q.SrcFunc.Args), q.SrcFunc.Args) } fc.vectorInfo, fc.vectorUid, err = interpretVFloatOrUid(q.SrcFunc.Args[1]) if err != nil { return nil, err } + if len(q.SrcFunc.Args) == 3 { + // Parse simple options: key=value pairs separated by comma or JSON-like {key:val,...} + raw := strings.TrimSpace(q.SrcFunc.Args[2]) + if len(raw) > 0 { + if strings.HasPrefix(raw, "{") && strings.HasSuffix(raw, "}") { + raw = strings.TrimSpace(raw[1 : len(raw)-1]) + } + parts := strings.Split(raw, ",") + for _, p := range parts { + kv := strings.SplitN(p, ":", 2) + if len(kv) != 2 { + kv = strings.SplitN(p, "=", 2) + if len(kv) != 2 { + continue + } + } + k := strings.ToLower(strings.TrimSpace(kv[0])) + v := strings.TrimSpace(kv[1]) + v = strings.Trim(v, "\"'") + switch k { + case "ef": + if n, perr := strconv.ParseInt(v, 10, 32); perr == nil && n > 0 { + fc.vsEfOverride = int(n) + } + case "distance_threshold": + if f, perr := strconv.ParseFloat(v, 64); perr == nil { + fc.vsDistanceThreshold = new(float64) + *fc.vsDistanceThreshold = f + } + default: + // ignore unknown keys silently + } + } + } + } case uidInFn: for _, arg := range q.SrcFunc.Args { uidParsed, err := strconv.ParseUint(arg, 0, 64)