1+ import json
2+ import multiprocessing
3+ from typing import Any , Mapping
4+
5+ import numpy as np
6+ from tiledb .cloud .dag import Mode
7+
8+ from tiledb .vector_search import index
9+ from tiledb .vector_search .module import *
10+ from tiledb .vector_search .storage_formats import (STORAGE_VERSION ,
11+ storage_formats ,
12+ validate_storage_version )
13+ from tiledb .vector_search .utils import add_to_group
14+ from tiledb .vector_search import _tiledbvspy as vspy
15+
# Largest value representable by an unsigned 64-bit integer.
MAX_UINT64 = np.iinfo(np.uint64).max

# Index-type label; VamanaIndex.__init__ stores it in `index_type`.
INDEX_TYPE = "VAMANA"
18+
class VamanaIndex(index.Index):
    """
    Open a Vamana index.

    Parameters
    ----------
    uri: str
        URI of the index
    config: Optional[Mapping[str, Any]]
        config dictionary, defaults to None
    timestamp:
        Forwarded unchanged to the base ``index.Index`` constructor.
    **kwargs:
        Accepted so callers may pass extra options; not used by this
        constructor.
    """

    def __init__(
        self,
        uri: str,
        config: Optional[Mapping[str, Any]] = None,
        timestamp=None,
        **kwargs,
    ):
        super().__init__(uri=uri, config=config, timestamp=timestamp)
        self.index_type = INDEX_TYPE
        self.index = vspy.IndexVamana(vspy.Ctx(config), uri)

        # NOTE(review): `storage_version`, `group`, `config` and `base_size`
        # are read below but never assigned here; presumably the base class
        # constructor sets them -- confirm against index.Index.
        array_names = storage_formats[self.storage_version]
        self.db_uri = self.group[array_names["PARTS_ARRAY_NAME"]].uri
        self.ids_uri = self.group[array_names["IDS_ARRAY_NAME"]].uri

        schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
        self.dimensions = self.index.dimension()

        # Inspect the raw metadata value *before* converting it:
        # np.dtype(None) silently yields float64, which would make the
        # fallback to the array schema unreachable.
        stored_dtype = self.group.meta.get("dtype", None)
        if stored_dtype is None:
            self.dtype = np.dtype(schema.attr("values").dtype)
        else:
            self.dtype = np.dtype(stored_dtype)

        # A base_size of -1 means the size has to be derived from the upper
        # bound of the array's second dimension (presumably the vector
        # axis -- confirm against the storage format).
        if self.base_size == -1:
            self.size = schema.domain.dim(1).domain[1] + 1
        else:
            self.size = self.base_size

    def get_dimensions(self):
        """Return the dimension reported by the underlying Vamana index."""
        return self.dimensions

    def query_internal(
        self,
        queries: np.ndarray,
        k: int = 10,
    ):
        """
        Query a VAMANA index

        Parameters
        ----------
        queries: numpy.ndarray
            ND Array of queries
        k: int
            Number of top results to return per query

        Returns
        -------
        A pair of results. For an empty index (``size == 0``) this is two
        arrays of shape ``(queries.shape[0], k)`` filled with
        ``index.MAX_FLOAT_32`` and ``index.MAX_UINT64`` respectively.
        Otherwise two empty lists are returned, because the actual query is
        not implemented yet.
        """
        if self.size == 0:
            return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
                (queries.shape[0], k), index.MAX_UINT64
            )

        assert queries.dtype == np.float32

        # Promote a single 1-D query vector to a one-row 2-D batch.
        if queries.ndim == 1:
            queries = np.array([queries])

        # TODO(paris): Actually run the query.
        return [], []
88+
89+ # TODO(paris): Pass more arguments to C++, i.e. storage_version.
def create(
    uri: str,
    dimensions: int,
    vector_type: np.dtype,
    id_type: np.dtype = np.uint32,
    adjacency_row_index_type: np.dtype = np.uint32,
    group_exists: bool = False,
    config: Optional[Mapping[str, Any]] = None,
    storage_version: str = STORAGE_VERSION,
    **kwargs,
) -> VamanaIndex:
    """
    Create an empty Vamana index at ``uri`` and open it.

    Parameters
    ----------
    uri: str
        URI where the index is written and then opened
    dimensions: int
        Number of dimensions of the indexed vectors
    vector_type: np.dtype
        Data type of the vectors
    id_type: np.dtype
        Data type of the ids, defaults to np.uint32
    adjacency_row_index_type: np.dtype
        Data type passed as ``adjacency_row_index_type`` to the native
        index, defaults to np.uint32
    group_exists: bool
        When True nothing is written and the index at ``uri`` is only opened
    config: Optional[Mapping[str, Any]]
        config dictionary, defaults to None
    storage_version: str
        Currently not used by this function (see the TODO about passing it
        to C++)
    **kwargs:
        Accepted but not used

    Returns
    -------
    The ``VamanaIndex`` opened at ``uri``.
    """
    if not group_exists:
        ctx = vspy.Ctx(config)
        feature_type_name = np.dtype(vector_type).name
        id_type_name = np.dtype(id_type).name
        adjacency_type_name = np.dtype(adjacency_row_index_type).name

        vamana = vspy.IndexVamana(
            feature_type=feature_type_name,
            id_type=id_type_name,
            adjacency_row_index_type=adjacency_type_name,
            dimension=dimensions,
        )
        # TODO(paris): Run all of this with a single C++ call.
        # Train on and add a zero-vector array, then write the index out.
        no_vectors = vspy.FeatureVectorArray(
            dimensions, 0, feature_type_name, id_type_name
        )
        vamana.train(no_vectors)
        vamana.add(no_vectors)
        vamana.write_index(ctx, uri)

    return VamanaIndex(uri=uri, config=config, memory_budget=1000000)
0 commit comments