Skip to content

Commit 84112d2

Browse files
authored
Support save and restore of index (#2)
* Normalize to speed up distance calc * Support save and restore of index * fix
1 parent 2c06460 commit 84112d2

File tree

9 files changed

+256
-123
lines changed

9 files changed

+256
-123
lines changed

bruteforce_test.go

Lines changed: 0 additions & 75 deletions
This file was deleted.

dist/dataset.bin

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:dea5f559d6d8be77d77c18f4ac67f62bf0482d3eb4a162bc773e92d2a519e0d0
3+
size 6996220

dist/dataset.gob

Lines changed: 0 additions & 3 deletions
This file was deleted.

example/main.go

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package main
22

33
import (
44
"bufio"
5-
"encoding/gob"
65
"fmt"
76
"math"
87
"os"
@@ -21,13 +20,7 @@ func main() {
2120
defer m.Close()
2221

2322
// Load a pre-embedded dataset and create an exact search index
24-
data, _ := loadDataset("../dist/dataset.gob")
25-
index := search.NewIndex[string]()
26-
27-
// Embed the sentences and calculate similarities
28-
for _, v := range data {
29-
index.Add(v.Vector, v.Pair[0]) // use m.EmbedText() for real-time embedding
30-
}
23+
index := loadIndex("../dist/dataset.bin")
3124

3225
r := bufio.NewReader(os.Stdin)
3326
for {
@@ -63,27 +56,12 @@ func main() {
6356
}
6457
}
6558

66-
type record struct {
67-
Pair [2]string `gob:"pair"`
68-
Rank float64 `gob:"rank"`
69-
Label string `gob:"label"`
70-
Vector []float32 `gob:"vector"`
71-
}
72-
73-
func loadDataset(path string) ([]record, error) {
74-
file, err := os.Open(path)
75-
if err != nil {
76-
return nil, err
77-
}
78-
defer file.Close()
79-
80-
var data []record
81-
r := gob.NewDecoder(file)
82-
if err := r.Decode(&data); err != nil {
83-
return nil, err
59+
func loadIndex(path string) *search.Index[string] {
60+
index := search.NewIndex[string]()
61+
if err := index.ReadFile(path); err != nil {
62+
panic(err)
8463
}
85-
86-
return data, nil
64+
return index
8765
}
8866

8967
/*

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ go 1.23
44

55
require (
66
github.com/ebitengine/purego v0.8.1
7+
github.com/kelindar/iostream v1.4.0
78
github.com/klauspost/cpuid/v2 v2.2.8
89
github.com/stretchr/testify v1.9.0
910
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
33
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
44
github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE=
55
github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
6+
github.com/kelindar/iostream v1.4.0 h1:ELKlinnM/K3GbRp9pYhWuZOyBxMMlYAfsOP+gauvZaY=
7+
github.com/kelindar/iostream v1.4.0/go.mod h1:MkjMuVb6zGdPQVdwLnFRO0xOTOdDvBWTztFmjRDQkXk=
68
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
79
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
810
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=

bruteforce.go renamed to index.go

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ type entry[T any] struct {
1919

2020
// Result represents a search result.
2121
type Result[T any] struct {
22-
entry[T]
2322
Relevance float64 // The relevance of the result
23+
Value T // The value of the result
2424
}
2525

2626
// Index represents a brute-force search index, returning exact results.
@@ -31,46 +31,52 @@ type Index[T any] struct {
3131
// NewIndex creates a new exact search index.
3232
func NewIndex[T any]() *Index[T] {
3333
return &Index[T]{
34-
arr: make([]entry[T], 0),
34+
arr: make([]entry[T], 0, 512),
3535
}
3636
}
3737

38+
// Len returns the number of items in the index.
39+
func (idx *Index[T]) Len() int {
40+
return len(idx.arr)
41+
}
42+
3843
// Add adds a new vector to the search index.
39-
func (b *Index[T]) Add(vx Vector, item T) {
44+
func (idx *Index[T]) Add(vx Vector, item T) {
4045
normalize(vx)
41-
42-
b.arr = append(b.arr, entry[T]{
46+
idx.arr = append(idx.arr, entry[T]{
4347
Vector: vx,
4448
Value: item,
4549
})
4650
}
4751

4852
// Search searches the index for the k-nearest neighbors of the query vector.
49-
func (b *Index[T]) Search(query Vector, k int) []Result[T] {
53+
func (idx *Index[T]) Search(query Vector, k int) []Result[T] {
5054
if k <= 0 {
5155
return nil
5256
}
5357

54-
// Normalize and quantize the query vector
58+
// Normalize the query vector
5559
normalize(query)
5660

57-
var relevance float64
61+
var r float64
5862
dst := make(minheap[T], 0, k)
59-
for _, v := range b.arr {
60-
simd.DotProduct(&relevance, query, v.Vector)
61-
result := Result[T]{
62-
entry: v,
63-
Relevance: relevance,
64-
}
63+
for _, v := range idx.arr {
64+
simd.DotProduct(&r, query, v.Vector)
6565

6666
// If the heap is not full, add the result, otherwise replace
6767
// the minimum element
6868
switch {
6969
case dst.Len() < k:
70-
dst.Push(result)
71-
case result.Relevance > dst[0].Relevance:
70+
dst.Push(Result[T]{
71+
Value: v.Value,
72+
Relevance: r,
73+
})
74+
case r > dst[0].Relevance:
7275
dst.Pop()
73-
dst.Push(result)
76+
dst.Push(Result[T]{
77+
Value: v.Value,
78+
Relevance: r,
79+
})
7480
}
7581
}
7682

index_codec.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for details.
3+
4+
package search
5+
6+
import (
7+
"compress/flate"
8+
"fmt"
9+
"io"
10+
"os"
11+
12+
"github.com/kelindar/iostream"
13+
)
14+
15+
// WriteTo writes the index to a writer.
16+
func (b *Index[T]) WriteTo(dst io.Writer) (int64, error) {
17+
w := iostream.NewWriter(dst)
18+
i := w.Offset()
19+
20+
// Write version
21+
if err := w.WriteUint8(1); err != nil {
22+
return 0, err
23+
}
24+
25+
// Write the index
26+
err := w.WriteRange(len(b.arr), func(i int, w *iostream.Writer) error {
27+
if err := w.WriteFloat32s(b.arr[i].Vector); err != nil {
28+
return err
29+
}
30+
31+
// Write the value (optional)
32+
switch v := any(b.arr[i].Value).(type) {
33+
case string:
34+
return w.WriteString(v)
35+
case []byte:
36+
return w.WriteBytes(v)
37+
default:
38+
return nil
39+
}
40+
})
41+
42+
return w.Offset() - i, err
43+
}
44+
45+
// ReadFrom reads the index from a reader.
46+
func (b *Index[T]) ReadFrom(src io.Reader) (int64, error) {
47+
r := iostream.NewReader(src)
48+
s := r.Offset()
49+
50+
// Read version
51+
version, err := r.ReadUint8()
52+
if err != nil {
53+
return 0, err
54+
}
55+
56+
if version != 1 {
57+
return 0, fmt.Errorf("unsupported version: %d", version)
58+
}
59+
60+
var length uint64
61+
if length, err = r.ReadUvarint(); err != nil {
62+
return r.Offset() - s, err
63+
}
64+
65+
// Allocate space for the entries
66+
b.arr = make([]entry[T], length)
67+
for i := 0; i < int(length); i++ {
68+
69+
// Read the vector
70+
if b.arr[i].Vector, err = r.ReadFloat32s(); err != nil {
71+
return r.Offset() - s, err
72+
}
73+
74+
// Read the value (optional)
75+
switch any(b.arr[i].Value).(type) {
76+
case string:
77+
v, err := r.ReadString()
78+
if err != nil {
79+
return r.Offset() - s, err
80+
}
81+
b.arr[i].Value = any(v).(T)
82+
83+
case []byte:
84+
v, err := r.ReadBytes()
85+
if err != nil {
86+
return r.Offset() - s, err
87+
}
88+
b.arr[i].Value = any(v).(T)
89+
}
90+
}
91+
92+
return r.Offset() - s, nil
93+
}
94+
95+
// ---------------------------------- File ----------------------------------
96+
97+
// WriteFile writes the index into a flate-compressed binary file.
98+
func (idx *Index[T]) WriteFile(filename string) error {
99+
file, err := os.Create(filename)
100+
if err != nil {
101+
return err
102+
}
103+
104+
defer file.Close()
105+
writer, err := flate.NewWriter(file, flate.DefaultCompression)
106+
if err != nil {
107+
return err
108+
}
109+
110+
// WriteTo the underlying writer
111+
defer writer.Close()
112+
_, err = idx.WriteTo(writer)
113+
return err
114+
}
115+
116+
// ReadFile reads the index from a flate-compressed binary file.
117+
func (idx *Index[T]) ReadFile(filename string) error {
118+
file, err := os.Open(filename)
119+
if err != nil {
120+
return err
121+
}
122+
123+
defer file.Close()
124+
_, err = idx.ReadFrom(flate.NewReader(file))
125+
return err
126+
}

0 commit comments

Comments
 (0)