|
| 1 | +from __future__ import absolute_import |
| 2 | +import psutil |
| 3 | +import os |
| 4 | +import time |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +import diskannpy |
| 8 | +import fast_refine |
| 9 | + |
| 10 | +from neurips23.streaming.base import BaseStreamingANN |
| 11 | + |
| 12 | +class Zilliz(BaseStreamingANN): |
| 13 | + def __init__(self, metric, index_params): |
| 14 | + self.name = "pyanns" |
| 15 | + if (index_params.get("R")==None): |
| 16 | + print("Error: missing parameter R") |
| 17 | + return |
| 18 | + if (index_params.get("L")==None): |
| 19 | + print("Error: missing parameter L") |
| 20 | + return |
| 21 | + self._index_params = index_params |
| 22 | + self._metric = metric |
| 23 | + |
| 24 | + self.R = index_params.get("R") |
| 25 | + self.L = index_params.get("L") |
| 26 | + self.insert_threads = index_params.get("insert_threads") |
| 27 | + self.consolidate_threads = index_params.get("consolidate_threads") |
| 28 | + self.mx = None |
| 29 | + self.mi = None |
| 30 | + |
| 31 | + def index_name(self): |
| 32 | + return f"R{self.R}_L{self.L}" |
| 33 | + |
| 34 | + def create_index_dir(self, dataset): |
| 35 | + index_dir = os.path.join(os.getcwd(), "data", "indices", "streaming") |
| 36 | + os.makedirs(index_dir, mode=0o777, exist_ok=True) |
| 37 | + index_dir = os.path.join(index_dir, 'pyanns') |
| 38 | + os.makedirs(index_dir, mode=0o777, exist_ok=True) |
| 39 | + index_dir = os.path.join(index_dir, dataset.short_name()) |
| 40 | + os.makedirs(index_dir, mode=0o777, exist_ok=True) |
| 41 | + index_dir = os.path.join(index_dir, self.index_name()) |
| 42 | + os.makedirs(index_dir, mode=0o777, exist_ok=True) |
| 43 | + return index_dir |
| 44 | + |
| 45 | + def translate_dist_fn(self, metric): |
| 46 | + if metric == 'euclidean': |
| 47 | + return 'l2' |
| 48 | + elif metric == 'ip': |
| 49 | + return 'mips' |
| 50 | + else: |
| 51 | + raise Exception('Invalid metric') |
| 52 | + |
| 53 | + def translate_dtype(self, dtype:str): |
| 54 | + return np.uint8 |
| 55 | + |
| 56 | + def setup(self, dtype, max_pts, ndim): |
| 57 | + self.index = diskannpy.DynamicMemoryIndex( |
| 58 | + distance_metric = self.translate_dist_fn(self._metric), |
| 59 | + vector_dtype = self.translate_dtype(dtype), |
| 60 | + max_vectors = max_pts, |
| 61 | + dimensions = ndim, |
| 62 | + graph_degree = self.R, |
| 63 | + complexity=self.L, |
| 64 | + num_threads = self.insert_threads, #to allocate scratch space for up to 64 search threads |
| 65 | + initial_search_complexity = 100 |
| 66 | + ) |
| 67 | + self.refiner = fast_refine.Refiner(ndim, max_pts) |
| 68 | + self.max_pts = max_pts |
| 69 | + print('Index class constructed and ready for update/search') |
| 70 | + self.active_indices = set() |
| 71 | + self.num_unprocessed_deletes = 0 |
| 72 | + |
| 73 | + def quant(self, X, mi, mx): |
| 74 | + return np.round(np.clip((X - mi) / (mx - mi) * 127.0, 0.0, 127.0)).astype('uint8') |
| 75 | + |
| 76 | + def insert(self, X, ids): |
| 77 | + if self.mi is None: |
| 78 | + self.mi = X.min() |
| 79 | + self.mx = X.max() |
| 80 | + |
| 81 | + self.refiner.batch_insert(X, ids) |
| 82 | + X = self.quant(X, self.mi, self.mx) |
| 83 | + self.active_indices.update(ids+1) |
| 84 | + print('#active pts', len(self.active_indices), '#unprocessed deletes', self.num_unprocessed_deletes) |
| 85 | + if len(self.active_indices) + self.num_unprocessed_deletes >= self.max_pts: |
| 86 | + self.index.consolidate_delete() |
| 87 | + self.num_unprocessed_deletes = 0 |
| 88 | + |
| 89 | + self.index.batch_insert(X, ids+1) |
| 90 | + |
| 91 | + def delete(self, ids): |
| 92 | + self.refiner.batch_delete(ids) |
| 93 | + for id in ids: |
| 94 | + self.index.mark_deleted(id+1) |
| 95 | + self.active_indices.difference_update(ids+1) |
| 96 | + self.num_unprocessed_deletes += len(ids) |
| 97 | + |
| 98 | + def query(self, X, k): |
| 99 | + """Carry out a batch query for k-NN of query set X.""" |
| 100 | + nq, d = X.shape |
| 101 | + Xq = self.quant(X, self.mi, self.mx) |
| 102 | + k_mul = 5 |
| 103 | + k_reorder = k * k_mul |
| 104 | + I, _ = self.index.batch_search( |
| 105 | + Xq, k_reorder, self.Ls, self.search_threads) |
| 106 | + I = I - 1 |
| 107 | + self.res = self.refiner.batch_refine(X, I, k).reshape(nq, k) |
| 108 | + |
| 109 | + def set_query_arguments(self, query_args): |
| 110 | + self._query_args = query_args |
| 111 | + self.Ls = 0 if query_args.get("Ls") == None else query_args.get("Ls") |
| 112 | + self.search_threads = self._query_args.get("T") |
| 113 | + |
| 114 | + def __str__(self): |
| 115 | + return f'pyanns({self.index_name(), self._query_args})' |
0 commit comments