2626
2727# Local dependencies
2828from .common import \
29+ isapprox , \
2930 test_data_svs , \
3031 test_data_vecs , \
32+ test_data_dims , \
3133 test_graph , \
32- test_vamana_config
34+ test_vamana_config , \
35+ test_close_lvq
3336
3437DEBUG = False ;
3538
@@ -38,9 +41,57 @@ class ReconstructionTester(unittest.TestCase):
3841 Test the reconstruction interface for indexex.
3942 """
4043 def _get_loaders (self , loader : svs .VectorDataLoader ):
44+ sequential = svs .LVQStrategy .Sequential
45+ turbo = svs .LVQStrategy .Turbo
46+
4147 return [
4248 # Uncompressed
4349 loader ,
50+ # LVQ
51+ svs .LVQLoader (loader , primary = 8 , padding = 0 ),
52+ svs .LVQLoader (loader , primary = 4 , padding = 0 ),
53+ svs .LVQLoader (
54+ loader , primary = 4 , residual = 4 , strategy = sequential , padding = 0
55+ ),
56+ svs .LVQLoader (
57+ loader , primary = 4 , residual = 4 , strategy = turbo , padding = 0
58+ ),
59+ svs .LVQLoader (
60+ loader , primary = 4 , residual = 8 , strategy = sequential , padding = 0
61+ ),
62+ svs .LVQLoader (
63+ loader , primary = 4 , residual = 8 , strategy = turbo , padding = 0
64+ ),
65+ svs .LVQLoader (loader , primary = 8 , residual = 8 , padding = 0 ),
66+
67+ # LeanVec
68+ svs .LeanVecLoader (
69+ loader ,
70+ leanvec_dims = 64 ,
71+ primary_kind = svs .LeanVecKind .float32 ,
72+ secondary_kind = svs .LeanVecKind .float32 ,
73+ ),
74+ svs .LeanVecLoader (
75+ loader ,
76+ leanvec_dims = 64 ,
77+ primary_kind = svs .LeanVecKind .lvq4 ,
78+ secondary_kind = svs .LeanVecKind .lvq8 ,
79+ alignment = 0
80+ ),
81+ svs .LeanVecLoader (
82+ loader ,
83+ leanvec_dims = 64 ,
84+ primary_kind = svs .LeanVecKind .lvq8 ,
85+ secondary_kind = svs .LeanVecKind .lvq8 ,
86+ alignment = 0
87+ ),
88+ svs .LeanVecLoader (
89+ loader ,
90+ leanvec_dims = 64 ,
91+ primary_kind = svs .LeanVecKind .lvq8 ,
92+ secondary_kind = svs .LeanVecKind .float16 ,
93+ alignment = 0
94+ ),
4495 ]
4596
4697 def _test_misc (self , loader : svs .VectorDataLoader , data ):
@@ -68,6 +119,30 @@ def _test_misc(self, loader: svs.VectorDataLoader, data):
68119 vamana .reconstruct (np .zeros ((10 , 10 ), dtype = np .uint64 )).shape == (10 , 10 , d )
69120 )
70121
122+ def _compare_lvq (self , data , reconstructed , loader : svs .LVQLoader ):
123+ print (f"LVQ: primary = { loader .primary_bits } , residual = { loader .residual_bits } " )
124+ self .assertTrue (isinstance (loader , svs .LVQLoader ))
125+ self .assertTrue (test_close_lvq (
126+ data ,
127+ reconstructed ,
128+ primary_bits = loader .primary_bits ,
129+ residual_bits = loader .residual_bits
130+ ))
131+
132+ def _compare_leanvec (self , data , reconstructed , loader : svs .LeanVecLoader ):
133+ self .assertTrue (isinstance (loader , svs .LeanVecLoader ))
134+ secondary_kind = loader .secondary_kind
135+ if secondary_kind == svs .LeanVecKind .float32 :
136+ self .assertTrue (np .array_equal (data , reconstructed ))
137+ elif secondary_kind == svs .LeanVecKind .float16 :
138+ self .assertTrue (np .allclose (data , reconstructed ))
139+ elif secondary_kind == svs .LeanVecKind .lvq4 :
140+ self .assertTrue (test_close_lvq (data , reconstructed , primary_bits = 4 ))
141+ elif secondary_kind == svs .LeanVecKind .lvq8 :
142+ self .assertTrue (test_close_lvq (data , reconstructed , primary_bits = 8 ))
143+ else :
144+ raise Exception (f"Unknown leanvec kind { secondary_kind } " )
145+
71146 def test_reconstruction (self ):
72147 default_loader = svs .VectorDataLoader (test_data_svs , svs .DataType .float32 )
73148 all_loaders = self ._get_loaders (default_loader )
@@ -88,6 +163,10 @@ def test_reconstruction(self):
88163
89164 if isinstance (loader , svs .VectorDataLoader ):
90165 self .assertTrue (np .array_equal (shuffled_data , r ))
166+ elif isinstance (loader , svs .LVQLoader ):
167+ self ._compare_lvq (shuffled_data , r , loader )
168+ elif isinstance (loader , svs .LeanVecLoader ):
169+ self ._compare_leanvec (shuffled_data , r , loader )
91170 else :
92171 raise Exception (f"Unhandled loader kind: { loader } " )
93172
0 commit comments