| 
 | 1 | +import numpy as np  | 
 | 2 | + | 
 | 3 | +from tiledb.vector_search.module import *  | 
 | 4 | +from tiledb.vector_search.storage_formats import storage_formats  | 
 | 5 | +from tiledb.vector_search.index import Index  | 
 | 6 | +from typing import Any, Mapping  | 
 | 7 | + | 
 | 8 | + | 
 | 9 | +class FlatIndex(Index):  | 
 | 10 | +    """  | 
 | 11 | +    Open a flat index  | 
 | 12 | +
  | 
 | 13 | +    Parameters  | 
 | 14 | +    ----------  | 
 | 15 | +    uri: str  | 
 | 16 | +        URI of the index  | 
 | 17 | +    config: Optional[Mapping[str, Any]]  | 
 | 18 | +        config dictionary, defaults to None  | 
 | 19 | +    """  | 
 | 20 | + | 
 | 21 | +    def __init__(  | 
 | 22 | +        self,  | 
 | 23 | +        uri: str,  | 
 | 24 | +        config: Optional[Mapping[str, Any]] = None,  | 
 | 25 | +    ):  | 
 | 26 | +        super().__init__(uri=uri, config=config)  | 
 | 27 | +        self.index_type = "FLAT"  | 
 | 28 | +        self._index = None  | 
 | 29 | +        self.db_uri = self.group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"] + self.index_version].uri  | 
 | 30 | +        schema = tiledb.ArraySchema.load(  | 
 | 31 | +            self.db_uri, ctx=tiledb.Ctx(self.config)  | 
 | 32 | +        )  | 
 | 33 | +        self.size = schema.domain.dim(1).domain[1]+1  | 
 | 34 | +        self._db = load_as_matrix(  | 
 | 35 | +            self.db_uri,  | 
 | 36 | +            ctx=self.ctx,  | 
 | 37 | +            config=config,  | 
 | 38 | +        )  | 
 | 39 | +        self.ids_uri = self.group[  | 
 | 40 | +            storage_formats[self.storage_version]["IDS_ARRAY_NAME"] + self.index_version  | 
 | 41 | +        ].uri  | 
 | 42 | +        if tiledb.array_exists(self.ids_uri, self.ctx):  | 
 | 43 | +            self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)  | 
 | 44 | +        else:  | 
 | 45 | +            self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))  | 
 | 46 | + | 
 | 47 | +        dtype = self.group.meta.get("dtype", None)  | 
 | 48 | +        if dtype is None:  | 
 | 49 | +            self.dtype = self._db.dtype  | 
 | 50 | +        else:  | 
 | 51 | +            self.dtype = np.dtype(dtype)  | 
 | 52 | + | 
 | 53 | +    def query_internal(  | 
 | 54 | +        self,  | 
 | 55 | +        queries: np.ndarray,  | 
 | 56 | +        k: int = 10,  | 
 | 57 | +        nthreads: int = 8,  | 
 | 58 | +    ):  | 
 | 59 | +        """  | 
 | 60 | +        Query a flat index  | 
 | 61 | +
  | 
 | 62 | +        Parameters  | 
 | 63 | +        ----------  | 
 | 64 | +        queries: numpy.ndarray  | 
 | 65 | +            ND Array of queries  | 
 | 66 | +        k: int  | 
 | 67 | +            Number of top results to return per query  | 
 | 68 | +        nthreads: int  | 
 | 69 | +            Number of threads to use for query  | 
 | 70 | +        """  | 
 | 71 | +        # TODO:  | 
 | 72 | +        # - typecheck queries  | 
 | 73 | +        # - add all the options and query strategies  | 
 | 74 | + | 
 | 75 | +        assert queries.dtype == np.float32  | 
 | 76 | + | 
 | 77 | +        queries_m = array_to_matrix(np.transpose(queries))  | 
 | 78 | +        d, i = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)  | 
 | 79 | + | 
 | 80 | +        return np.transpose(np.array(d)), np.transpose(np.array(i))  | 
0 commit comments