@@ -76,22 +76,40 @@ class IVFFlatIndex(Index):
7676 URI of dataset
7777 dtype: numpy.dtype
7878 datatype float32 or uint8
79+ memory_budget: int
80+ Main memory budget. If not provided, no memory budget is applied.
7981 """
8082
81- def __init__ (self , uri , dtype : np .dtype ):
83+ def __init__ (
84+ self , uri , dtype : np .dtype , memory_budget : int = - 1 , ctx : "Ctx" = None
85+ ):
8286 self .parts_db_uri = os .path .join (uri , "parts.tdb" )
8387 self .centroids_uri = os .path .join (uri , "centroids.tdb" )
8488 self .index_uri = os .path .join (uri , "index.tdb" )
8589 self .ids_uri = os .path .join (uri , "ids.tdb" )
8690 self .dtype = dtype
91+ self .memory_budget = memory_budget
92+ self .ctx = ctx
93+ if ctx is None :
94+ self .ctx = Ctx ({})
95+
96+ # TODO pass in a context
97+ if self .memory_budget == - 1 :
98+ self ._db = load_as_matrix (self .parts_db_uri )
99+ self ._ids = read_vector_u64 (self .ctx , self .ids_uri )
87100
88- ctx = Ctx ({}) # TODO pass in a context
89- self ._db = load_as_matrix (self .parts_db_uri )
90101 self ._centroids = load_as_matrix (self .centroids_uri )
91- self ._index = read_vector_u64 (ctx , self .index_uri )
92- self ._ids = read_vector_u64 (ctx , self .ids_uri )
102+ self ._index = read_vector_u64 (self .ctx , self .index_uri )
93103
94- def query (self , targets : np .ndarray , k = 10 , nqueries = 10 , nthreads = 8 , nprobe = 1 ):
104+ def query (
105+ self ,
106+ targets : np .ndarray ,
107+ k = 10 ,
108+ nqueries = 10 ,
109+ nthreads = 8 ,
110+ nprobe = 1 ,
111+ use_nuv_implementation : bool = False ,
112+ ):
95113 """
96114 Open a flat index
97115
@@ -107,21 +125,42 @@ def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
107125 Number of threads to use for query
108126 nprobe: int
109127 number of probes
128+ use_nuv_implementation: bool
129+ whether to use the nuv query implementation. Default: False
110130 """
111131 assert targets .dtype == np .float32
112132
113133 targets_m = array_to_matrix (targets )
134+ if self .memory_budget == - 1 :
135+ r = ivf_query_ram (
136+ self .dtype ,
137+ self ._db ,
138+ self ._centroids ,
139+ targets_m ,
140+ self ._index ,
141+ self ._ids ,
142+ nprobe = nprobe ,
143+ k_nn = k ,
144+ nth = True , # ??
145+ nthreads = nthreads ,
146+ ctx = self .ctx ,
147+ use_nuv_implementation = use_nuv_implementation ,
148+ )
149+ else :
150+ r = ivf_query (
151+ self .dtype ,
152+ self .parts_db_uri ,
153+ self ._centroids ,
154+ targets_m ,
155+ self ._index ,
156+ self .ids_uri ,
157+ nprobe = nprobe ,
158+ k_nn = k ,
159+ memory_budget = self .memory_budget ,
160+ nth = True , # ??
161+ nthreads = nthreads ,
162+ ctx = self .ctx ,
163+ use_nuv_implementation = use_nuv_implementation ,
164+ )
114165
115- r = query_kmeans (
116- self ._db .dtype ,
117- self ._db ,
118- self ._centroids ,
119- targets_m ,
120- self ._index ,
121- self ._ids ,
122- nprobe = nprobe ,
123- k_nn = k ,
124- nth = True , # ??
125- nthreads = nthreads ,
126- )
127166 return np .array (r )
0 commit comments