@@ -44,10 +44,11 @@ func TestPartition(t *testing.T) {
44
44
valueBytes20b := ValueBytes {11 , 12 }
45
45
46
46
var workspace workspace.T
47
- quantizer := quantize .NewUnQuantizer (2 , vecpb .L2SquaredDistance )
47
+ unquantizer := quantize .NewUnQuantizer (2 , vecpb .L2SquaredDistance )
48
+ rabitq := quantize .NewRaBitQuantizer (2 , 42 , vecpb .InnerProductDistance )
48
49
49
50
// newTestPartition creates a partition with 2 vectors.
50
- newTestPartition := func () * Partition {
51
+ newTestPartition := func (quantizer quantize. Quantizer ) * Partition {
51
52
vectors := vector .MakeSet (2 )
52
53
vectors .Add (vec10 )
53
54
vectors .Add (vec20 )
@@ -69,7 +70,7 @@ func TestPartition(t *testing.T) {
69
70
70
71
t .Run ("test Init" , func (t * testing.T ) {
71
72
// Validate that Init sets same values.
72
- partition := newTestPartition ()
73
+ partition := newTestPartition (unquantizer )
73
74
var partition2 Partition
74
75
partition2 .Init (
75
76
* partition .Metadata (),
@@ -82,7 +83,7 @@ func TestPartition(t *testing.T) {
82
83
})
83
84
84
85
t .Run ("test Clone" , func (t * testing.T ) {
85
- partition := newTestPartition ()
86
+ partition := newTestPartition (unquantizer )
86
87
cloned := partition .Clone ()
87
88
require .Equal (t , partition , cloned )
88
89
@@ -101,7 +102,7 @@ func TestPartition(t *testing.T) {
101
102
})
102
103
103
104
t .Run ("test Add" , func (t * testing.T ) {
104
- partition := newTestPartition ()
105
+ partition := newTestPartition (unquantizer )
105
106
require .True (t , partition .Add (& workspace , vec40 , childKey40 , valueBytes40 , true /* overwrite */ ))
106
107
require .Equal (t , 4 , partition .Count ())
107
108
require .Equal (t , []ChildKey {childKey10 , childKey20 , childKey30 , childKey40 }, partition .ChildKeys ())
@@ -111,14 +112,14 @@ func TestPartition(t *testing.T) {
111
112
checkPartitionMetadata (t , partition .Metadata (), Level (1 ), vector.T {4 , 3.33 })
112
113
113
114
// Add vector with duplicate key and overwrite=false. Expect no-op.
114
- partition = newTestPartition ()
115
+ partition = newTestPartition (unquantizer )
115
116
require .False (t , partition .Add (& workspace , vec20b , childKey20 , valueBytes20b , false /* overwrite */ ))
116
117
require .Equal (t , 3 , partition .Count ())
117
118
require .Equal (t , []ValueBytes {valueBytes10 , valueBytes20 , valueBytes30 }, partition .ValueBytes ())
118
119
119
120
// Add vector with duplicate key and overwrite=true. Expect value to be
120
121
// updated.
121
- partition = newTestPartition ()
122
+ partition = newTestPartition (unquantizer )
122
123
require .False (t , partition .Add (& workspace , vec20b , childKey20 , valueBytes20b , true /* overwrite */ ))
123
124
require .Equal (t , 3 , partition .Count ())
124
125
require .Equal (t , []ChildKey {childKey10 , childKey30 , childKey20 }, partition .ChildKeys ())
@@ -129,7 +130,7 @@ func TestPartition(t *testing.T) {
129
130
t .Run ("test AddSet" , func (t * testing.T ) {
130
131
// Create empty partition.
131
132
metadata := PartitionMetadata {Level : 1 , Centroid : vector.T {4 , 3 }}
132
- partition := CreateEmptyPartition (quantizer , metadata )
133
+ partition := CreateEmptyPartition (unquantizer , metadata )
133
134
134
135
// Add empty set.
135
136
vectors := vector .MakeSet (2 )
@@ -188,7 +189,7 @@ func TestPartition(t *testing.T) {
188
189
t .Run ("test Search" , func (t * testing.T ) {
189
190
// Search empty partition.
190
191
metadata := PartitionMetadata {Level : LeafLevel , Centroid : vector.T {4 , 3 }}
191
- partition := CreateEmptyPartition (quantizer , metadata )
192
+ partition := CreateEmptyPartition (unquantizer , metadata )
192
193
require .Equal (t , Level (1 ), partition .Level ())
193
194
194
195
searchSet := SearchSet {MaxResults : 1 }
@@ -198,7 +199,7 @@ func TestPartition(t *testing.T) {
198
199
require .Equal (t , SearchResults (nil ), results )
199
200
200
201
// Search partition with 5 vectors.
201
- partition = newTestPartition ()
202
+ partition = newTestPartition (unquantizer )
202
203
vectors := vector .MakeSet (2 )
203
204
vectors .Add (vec40 )
204
205
vectors .Add (vec50 )
@@ -219,8 +220,24 @@ func TestPartition(t *testing.T) {
219
220
require .Equal (t , SearchResults {result1 , result2 , result3 }, results )
220
221
})
221
222
223
+ t .Run ("test Search with IncludeCentroidDistances" , func (t * testing.T ) {
224
+ // Search partition with 3 vectors.
225
+ partition := newTestPartition (rabitq )
226
+
227
+ searchSet := SearchSet {MaxResults : 2 , IncludeCentroidDistances : true }
228
+ _ = partition .Search (& workspace , RootKey , vector.T {1 , 1 }, & searchSet )
229
+ result1 := SearchResult {
230
+ QueryDistance : - 11.52 , ErrorBound : 8.96 , CentroidDistance : - 8.45 , ParentPartitionKey : 1 ,
231
+ ChildKey : childKey30 , ValueBytes : valueBytes30 }
232
+ result2 := SearchResult {
233
+ QueryDistance : - 6.1 , ErrorBound : 4.48 , CentroidDistance : - 5.12 , ParentPartitionKey : 1 ,
234
+ ChildKey : childKey20 , ValueBytes : valueBytes20 }
235
+ results := roundResults (searchSet .PopResults (), 2 )
236
+ require .Equal (t , SearchResults {result1 , result2 }, results )
237
+ })
238
+
222
239
t .Run ("test ReplaceWithLast" , func (t * testing.T ) {
223
- partition := newTestPartition ()
240
+ partition := newTestPartition (unquantizer )
224
241
partition .ReplaceWithLast (0 )
225
242
require .Equal (t , 2 , partition .Count ())
226
243
require .Equal (t , []ChildKey {childKey30 , childKey20 }, partition .ChildKeys ())
@@ -237,7 +254,7 @@ func TestPartition(t *testing.T) {
237
254
})
238
255
239
256
t .Run ("test ReplaceWithLastByKey and Find" , func (t * testing.T ) {
240
- partition := newTestPartition ()
257
+ partition := newTestPartition (unquantizer )
241
258
require .Equal (t , 0 , partition .Find (childKey10 ))
242
259
require .Equal (t , 2 , partition .Find (childKey30 ))
243
260
require .True (t , partition .ReplaceWithLastByKey (childKey10 ))
@@ -252,7 +269,7 @@ func TestPartition(t *testing.T) {
252
269
})
253
270
254
271
t .Run ("test Clear" , func (t * testing.T ) {
255
- partition := newTestPartition ()
272
+ partition := newTestPartition (unquantizer )
256
273
require .Equal (t , 3 , partition .Clear ())
257
274
require .Equal (t , 0 , partition .Count ())
258
275
require .Equal (t , []ChildKey {}, partition .ChildKeys ())
@@ -273,6 +290,7 @@ func roundResults(results SearchResults, prec int) SearchResults {
273
290
result := & results [i ]
274
291
result .QueryDistance = float32 (scalar .Round (float64 (result .QueryDistance ), prec ))
275
292
result .ErrorBound = float32 (scalar .Round (float64 (result .ErrorBound ), prec ))
293
+ result .CentroidDistance = float32 (scalar .Round (float64 (result .CentroidDistance ), prec ))
276
294
result .Vector = testutils .RoundFloats (result .Vector , prec )
277
295
}
278
296
return results
0 commit comments