Skip to content

Commit 47eea43

Browse files
bhshkhtelpirion
andauthored
feat(firestore): Distance result field and distance threshold in vector search (#4362)
* feat(firestore): Vector search * test(firestore): Clean up test resources * refactor(firestore): Refactoring tests * feat(firestore): Distance result field and distance threshold in vector search * feat(firestore): Updating branch * feat(firestore): Add link to documentation --------- Co-authored-by: Eric Schmidt <[email protected]>
1 parent c4ce89c commit 47eea43

7 files changed

+307
-0
lines changed

firestore/vector_search_basic.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ func vectorSearchBasic(w io.Writer, projectID string) error {
3636
collection := client.Collection("coffee-beans")
3737

3838
// Requires a vector index
39+
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
3940
vectorQuery := collection.FindNearest("embedding_field",
4041
[]float32{3.0, 1.0, 2.0},
4142
5,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package firestore
16+
17+
// [START firestore_vector_search_basic]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/firestore"
24+
)
25+
26+
func vectorSearchDistanceThreshold(w io.Writer, projectID string) error {
27+
ctx := context.Background()
28+
29+
client, err := firestore.NewClient(ctx, projectID)
30+
if err != nil {
31+
return fmt.Errorf("firestore.NewClient: %w", err)
32+
}
33+
defer client.Close()
34+
35+
collection := client.Collection("coffee-beans")
36+
37+
// Requires a vector index
38+
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
39+
vectorQuery := collection.FindNearest("embedding_field",
40+
[]float32{3.0, 1.0, 2.0},
41+
10,
42+
firestore.DistanceMeasureEuclidean,
43+
&firestore.FindNearestOptions{
44+
DistanceThreshold: firestore.Ptr[float64](4.5),
45+
})
46+
47+
docs, err := vectorQuery.Documents(ctx).GetAll()
48+
if err != nil {
49+
fmt.Fprintf(w, "failed to get vector query results: %v", err)
50+
return err
51+
}
52+
53+
for _, doc := range docs {
54+
fmt.Fprintln(w, doc.Data()["name"])
55+
}
56+
return nil
57+
}
58+
59+
// [END firestore_vector_search_basic]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package firestore
16+
17+
import (
18+
"bytes"
19+
"os"
20+
"strings"
21+
"testing"
22+
)
23+
24+
func TestVectorSearchDistanceThreshold(t *testing.T) {
25+
projectID := os.Getenv("GOLANG_SAMPLES_FIRESTORE_PROJECT")
26+
if projectID == "" {
27+
t.Skip("Skipping firestore test. Set GOLANG_SAMPLES_FIRESTORE_PROJECT.")
28+
}
29+
30+
buf := new(bytes.Buffer)
31+
if err := vectorSearchDistanceThreshold(buf, projectID); err != nil {
32+
t.Errorf("vectorSearchDistanceThreshold: %v", err)
33+
}
34+
35+
// Compare console outputs
36+
got := buf.String()
37+
want := "Sleepy coffee beans\n" +
38+
"Kahawa coffee beans\n"
39+
if !strings.Contains(got, want) {
40+
t.Errorf("got %q, want %q", got, want)
41+
}
42+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package firestore
16+
17+
// [START firestore_vector_search_distance_result_field]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/firestore"
24+
)
25+
26+
func vectorSearchDistanceResultField(w io.Writer, projectID string) error {
27+
ctx := context.Background()
28+
29+
client, err := firestore.NewClient(ctx, projectID)
30+
if err != nil {
31+
return fmt.Errorf("firestore.NewClient: %w", err)
32+
}
33+
defer client.Close()
34+
35+
collection := client.Collection("coffee-beans")
36+
37+
// Requires a vector index
38+
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
39+
vectorQuery := collection.FindNearest("embedding_field",
40+
[]float32{3.0, 1.0, 2.0},
41+
10,
42+
firestore.DistanceMeasureEuclidean,
43+
&firestore.FindNearestOptions{
44+
DistanceResultField: "vector_distance",
45+
})
46+
47+
docs, err := vectorQuery.Documents(ctx).GetAll()
48+
if err != nil {
49+
fmt.Fprintf(w, "failed to get vector query results: %v", err)
50+
return err
51+
}
52+
53+
for _, doc := range docs {
54+
fmt.Fprintf(w, "%v, Distance: %v\n", doc.Data()["name"], doc.Data()["vector_distance"])
55+
}
56+
return nil
57+
}
58+
59+
// [END firestore_vector_search_distance_result_field]
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package firestore
16+
17+
// [START firestore_vector_search_distance_result_field]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/firestore"
24+
)
25+
26+
func vectorSearchDistanceResultFieldMasked(w io.Writer, projectID string) error {
27+
ctx := context.Background()
28+
29+
client, err := firestore.NewClient(ctx, projectID)
30+
if err != nil {
31+
return fmt.Errorf("firestore.NewClient: %w", err)
32+
}
33+
defer client.Close()
34+
35+
collection := client.Collection("coffee-beans")
36+
37+
// Requires a vector index
38+
// https://firebase.google.com/docs/firestore/vector-search#create_and_manage_vector_indexes
39+
vectorQuery := collection.Select("color", "vector_distance").
40+
FindNearest("embedding_field",
41+
[]float32{3.0, 1.0, 2.0},
42+
10,
43+
firestore.DistanceMeasureEuclidean,
44+
&firestore.FindNearestOptions{
45+
DistanceResultField: "vector_distance",
46+
})
47+
48+
docs, err := vectorQuery.Documents(ctx).GetAll()
49+
if err != nil {
50+
fmt.Fprintf(w, "failed to get vector query results: %v", err)
51+
return err
52+
}
53+
54+
for _, doc := range docs {
55+
fmt.Fprintf(w, "%v, Distance: %v\n", doc.Data()["color"], doc.Data()["vector_distance"])
56+
}
57+
return nil
58+
}
59+
60+
// [END firestore_vector_search_distance_result_field]
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package firestore
16+
17+
import (
18+
"bytes"
19+
"os"
20+
"strings"
21+
"testing"
22+
)
23+
24+
func TestVectorSearchDistanceResultFieldMasked(t *testing.T) {
25+
projectID := os.Getenv("GOLANG_SAMPLES_FIRESTORE_PROJECT")
26+
if projectID == "" {
27+
t.Skip("Skipping firestore test. Set GOLANG_SAMPLES_FIRESTORE_PROJECT.")
28+
}
29+
30+
buf := new(bytes.Buffer)
31+
if err := vectorSearchDistanceResultFieldMasked(buf, projectID); err != nil {
32+
t.Errorf("vectorSearchDistanceResultFieldMasked: %v", err)
33+
}
34+
35+
// Compare console outputs
36+
got := buf.String()
37+
want := "red, Distance: 0\n" +
38+
"red, Distance: 2.449489742783178\n" +
39+
"brown, Distance: 5.744562646538029\n"
40+
if !strings.Contains(got, want) {
41+
t.Errorf("got %q, want %q", got, want)
42+
}
43+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package firestore
16+
17+
import (
18+
"bytes"
19+
"os"
20+
"strings"
21+
"testing"
22+
)
23+
24+
func TestVectorSearchDistanceResultField(t *testing.T) {
25+
projectID := os.Getenv("GOLANG_SAMPLES_FIRESTORE_PROJECT")
26+
if projectID == "" {
27+
t.Skip("Skipping firestore test. Set GOLANG_SAMPLES_FIRESTORE_PROJECT.")
28+
}
29+
30+
buf := new(bytes.Buffer)
31+
if err := vectorSearchDistanceResultField(buf, projectID); err != nil {
32+
t.Errorf("vectorSearchDistanceResultField: %v", err)
33+
}
34+
35+
// Compare console outputs
36+
got := buf.String()
37+
want := "Sleepy coffee beans, Distance: 0\n" +
38+
"Kahawa coffee beans, Distance: 2.449489742783178\n" +
39+
"Owl coffee beans, Distance: 5.744562646538029\n"
40+
if !strings.Contains(got, want) {
41+
t.Errorf("got %q, want %q", got, want)
42+
}
43+
}

0 commit comments

Comments
 (0)