Commit 3e5489c

Let user specify a set of vectors to use for training in IVF_FLAT (#169)
1 parent 60fc836 commit 3e5489c

File tree

5 files changed: 247 additions & 41 deletions

5 files changed

+247
-41
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 103 additions & 26 deletions
@@ -27,6 +27,9 @@ def ingest(
     partitions: int = -1,
     copy_centroids_uri: str = None,
     training_sample_size: int = -1,
+    training_input_vectors: np.ndarray = None,
+    training_source_uri: str = None,
+    training_source_type: str = None,
     workers: int = -1,
     input_vectors_per_work_item: int = -1,
     max_tasks_per_stage: int= -1,
@@ -78,6 +81,18 @@ def ingest(
     training_sample_size: int = -1
         vector sample size to train centroids with,
         if not provided, is auto-configured based on the dataset sizes
+        should not be provided if training_source_uri is provided
+    training_input_vectors: numpy Array
+        Training input vectors, if this is provided it takes precedence over training_source_uri and training_source_type
+        should not be provided if training_sample_size or training_source_uri is provided
+    training_source_uri: str = None
+        The source URI to use for training centroids when building an IVF_FLAT vector index,
+        if not provided, the first training_sample_size vectors from source_uri are used
+        should not be provided if training_sample_size or training_input_vectors is provided
+    training_source_type: str = None
+        Type of the training source data in training_source_uri
+        if left empty, is auto-detected from the suffix of training_source_uri
+        should only be provided when training_source_uri is provided
     workers: int = -1
         number of workers for vector ingestion,
         if not provided, is auto-configured based on the dataset size
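
The docstring above describes three mutually exclusive ways to supply training data. The hedged usage sketch below is not part of the diff: index_type, the URIs, and the file suffixes are placeholder assumptions layered on the existing ingest signature.

# Hedged usage sketch: three mutually exclusive ways to pick centroid-training
# vectors for an IVF_FLAT index. URIs and index_type are placeholders.
import numpy as np
from tiledb.vector_search.ingestion import ingest

# 1) Sample the first N vectors of the ingestion source (existing behaviour).
ingest(
    index_type="IVF_FLAT",
    index_uri="s3://bucket/my_index",
    source_uri="s3://bucket/vectors.f32bin",
    training_sample_size=100_000,
)

# 2) Train on an explicit in-memory array (new in this commit).
training = np.random.rand(10_000, 128).astype(np.float32)
ingest(
    index_type="IVF_FLAT",
    index_uri="s3://bucket/my_index",
    source_uri="s3://bucket/vectors.f32bin",
    training_input_vectors=training,
)

# 3) Train on a separate source URI (new in this commit); training_source_type
#    may be omitted and auto-detected from the URI suffix.
ingest(
    index_type="IVF_FLAT",
    index_uri="s3://bucket/my_index",
    source_uri="s3://bucket/vectors.f32bin",
    training_source_uri="s3://bucket/training_vectors.fvec",
    training_source_type="FVEC",
)
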
@@ -121,6 +136,29 @@ def ingest(
 
     validate_storage_version(storage_version)
 
+    if source_type and not source_uri:
+        raise ValueError("source_type should not be provided without source_uri")
+    if source_uri and input_vectors:
+        raise ValueError("source_uri should not be provided alongside input_vectors")
+    if source_type and input_vectors:
+        raise ValueError("source_type should not be provided alongside input_vectors")
+
+    if training_source_uri and training_sample_size != -1:
+        raise ValueError("training_source_uri and training_sample_size should not both be provided")
+    if training_source_uri and training_input_vectors is not None:
+        raise ValueError("training_source_uri and training_input_vectors should not both be provided")
+
+    if training_input_vectors is not None and training_sample_size != -1:
+        raise ValueError("training_input_vectors and training_sample_size should not both be provided")
+    if training_input_vectors is not None and training_source_type:
+        raise ValueError("training_input_vectors and training_source_type should not both be provided")
+
+    if training_source_type and not training_source_uri:
+        raise ValueError("training_source_type should not be provided without training_source_uri")
+
+    if training_sample_size < -1:
+        raise ValueError("training_sample_size should either be positive or -1 to auto-configure based on the dataset sizes")
+
     # use index_group_uri for internal clarity
     index_group_uri = index_uri
 
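
The new guards above reject conflicting training options up front. A minimal sketch of the behaviour they enforce, reusing the placeholder arguments from the earlier example:

from tiledb.vector_search.ingestion import ingest

try:
    # A training sample size and an explicit training source together: rejected.
    ingest(
        index_type="IVF_FLAT",
        index_uri="s3://bucket/my_index",
        source_uri="s3://bucket/vectors.f32bin",
        training_sample_size=100_000,
        training_source_uri="s3://bucket/training_vectors.fvec",
    )
except ValueError as err:
    # "training_source_uri and training_sample_size should not both be provided"
    print(err)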

@@ -131,6 +169,9 @@
     INPUT_VECTORS_ARRAY_NAME = storage_formats[storage_version][
         "INPUT_VECTORS_ARRAY_NAME"
     ]
+    TRAINING_INPUT_VECTORS_ARRAY_NAME = storage_formats[storage_version][
+        "TRAINING_INPUT_VECTORS_ARRAY_NAME"
+    ]
     EXTERNAL_IDS_ARRAY_NAME = storage_formats[storage_version][
         "EXTERNAL_IDS_ARRAY_NAME"
     ]
@@ -248,16 +289,17 @@ def read_source_metadata(
             size = int(file_size / vector_size)
             return size, dimensions, np.uint8
         else:
-            raise ValueError(f"Not supported source_type {source_type}")
+            raise ValueError(f"Not supported source_type {source_type} - valid types are [TILEDB_ARRAY, U8BIN, F32BIN, FVEC, IVEC, BVEC]")
 
     def write_input_vectors(
         group: tiledb.Group,
         input_vectors: np.ndarray,
         size: int,
         dimensions: int,
         vector_type: np.dtype,
+        array_name: str
     ) -> str:
-        input_vectors_array_uri = f"{group.uri}/{INPUT_VECTORS_ARRAY_NAME}"
+        input_vectors_array_uri = f"{group.uri}/{array_name}"
         if tiledb.array_exists(input_vectors_array_uri):
             raise ValueError(f"Array exists {input_vectors_array_uri}")
         tile_size = min(
@@ -295,7 +337,7 @@ def write_input_vectors(
         )
         logger.debug(input_vectors_array_schema)
         tiledb.Array.create(input_vectors_array_uri, input_vectors_array_schema)
-        group.add(input_vectors_array_uri, name=INPUT_VECTORS_ARRAY_NAME)
+        group.add(input_vectors_array_uri, name=array_name)
 
         input_vectors_array = tiledb.open(
             input_vectors_array_uri, "w", timestamp=index_timestamp
@@ -749,8 +791,9 @@ def centralised_kmeans(
         vector_type: np.dtype,
         partitions: int,
         dimensions: int,
-        sample_start_pos: int,
-        sample_end_pos: int,
+        training_sample_size: int,
+        training_source_uri: Optional[str],
+        training_source_type: Optional[str],
         init: str = "random",
         max_iter: int = 10,
         n_init: int = 1,
@@ -765,45 +808,61 @@ def centralised_kmeans(
             array_to_matrix,
             kmeans_fit,
         )
+
         with tiledb.scope_ctx(ctx_or_config=config):
             logger = setup(config, verbose)
             group = tiledb.Group(index_group_uri)
             centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
-            verb = 0
-            if verbose:
-                verb = 3
-
-            if sample_end_pos - sample_start_pos >= partitions:
-                sample_vectors = read_input_vectors(
-                    source_uri=source_uri,
-                    source_type=source_type,
-                    vector_type=vector_type,
-                    dimensions=dimensions,
-                    start_pos=sample_start_pos,
-                    end_pos=sample_end_pos,
-                    config=config,
-                    verbose=verbose,
-                    trace_id=trace_id,
-                ).astype(np.float32)
+            if training_sample_size >= partitions:
+                if training_source_uri:
+                    if training_source_type is None:
+                        training_source_type = autodetect_source_type(source_uri=training_source_uri)
+                    training_in_size, training_dimensions, training_vector_type = read_source_metadata(source_uri=training_source_uri, source_type=training_source_type)
+                    dimensions = training_dimensions
+                    sample_vectors = read_input_vectors(
+                        source_uri=training_source_uri,
+                        source_type=training_source_type,
+                        vector_type=training_vector_type,
+                        dimensions=training_dimensions,
+                        start_pos=0,
+                        end_pos=training_in_size,
+                        config=config,
+                        verbose=verbose,
+                        trace_id=trace_id,
+                    ).astype(np.float32)
+                else:
+                    sample_vectors = read_input_vectors(
+                        source_uri=source_uri,
+                        source_type=source_type,
+                        vector_type=vector_type,
+                        dimensions=dimensions,
+                        start_pos=0,
+                        end_pos=training_sample_size,
+                        config=config,
+                        verbose=verbose,
+                        trace_id=trace_id,
+                    ).astype(np.float32)
 
+                logger.debug("Start kmeans training")
                 if use_sklearn:
                     km = KMeans(
                         n_clusters=partitions,
                         init=init,
                         max_iter=max_iter,
-                        verbose=verb,
+                        verbose=3 if verbose else 0,
                         n_init=n_init,
+                        random_state=0,
                     )
                     km.fit_predict(sample_vectors)
                     centroids = np.transpose(np.array(km.cluster_centers_))
                 else:
                     centroids = kmeans_fit(partitions, init, max_iter, verbose, n_init, array_to_matrix(np.transpose(sample_vectors)))
                     centroids = np.array(centroids) # TODO: why is this here?
             else:
+                # TODO(paris): Should we instead take the first training_sample_size vectors and then fill in random for the rest? Or raise an error like this:
+                # raise ValueError(f"We have a training_sample_size of {training_sample_size} but {partitions} partitions - training_sample_size must be >= partitions")
                 centroids = np.random.rand(dimensions, partitions)
 
-            logger.debug("Start kmeans training")
-
             logger.debug("Writing centroids to array %s", centroids_uri)
             with tiledb.open(centroids_uri, mode="w", timestamp=index_timestamp) as A:
                 A[0:dimensions, 0:partitions] = centroids
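
Stripped of the TileDB I/O, the selection order implemented above is: an explicit training set wins, otherwise the first training_sample_size vectors of the ingestion source are used, and when there are fewer training vectors than partitions the centroids fall back to random values. The following toy, in-memory mirror of that flow is not part of the diff and only illustrates the branching:

import numpy as np

def choose_training_vectors(source, training_vectors, training_sample_size, partitions):
    """Toy mirror of the centralised_kmeans selection order:
    explicit training set > first training_sample_size source vectors > random centroids."""
    if training_vectors is not None and len(training_vectors) >= partitions:
        return training_vectors.astype(np.float32)
    if training_sample_size >= partitions:
        return source[:training_sample_size].astype(np.float32)
    # Too few samples to train: caller substitutes np.random.rand(dimensions, partitions).
    return None

source = np.random.rand(1000, 16).astype(np.float32)
print(choose_training_vectors(source, None, 100, partitions=10).shape)  # (100, 16)
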
@@ -1487,6 +1546,8 @@ def create_ingestion_dag(
         dimensions: int,
         copy_centroids_uri: str,
         training_sample_size: int,
+        training_source_uri: Optional[str],
+        training_source_type: Optional[str],
         input_vectors_per_work_item: int,
         input_vectors_work_items_per_worker: int,
         table_partitions_per_work_item: int,
@@ -1569,8 +1630,9 @@ def create_ingestion_dag(
                 vector_type=vector_type,
                 partitions=partitions,
                 dimensions=dimensions,
-                sample_start_pos=0,
-                sample_end_pos=training_sample_size,
+                training_sample_size=training_sample_size,
+                training_source_uri=training_source_uri,
+                training_source_type=training_source_type,
                 config=config,
                 verbose=verbose,
                 trace_id=trace_id,
@@ -1835,6 +1897,17 @@ def consolidate_and_vacuum(
     group.close()
     group = tiledb.Group(index_group_uri, "w")
 
+    if training_input_vectors is not None:
+        training_source_uri = write_input_vectors(
+            group=group,
+            input_vectors=training_input_vectors,
+            size=training_input_vectors.shape[0],
+            dimensions=training_input_vectors.shape[1],
+            vector_type=training_input_vectors.dtype,
+            array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME
+        )
+        training_source_type = "TILEDB_ARRAY"
+
     if input_vectors is not None:
         in_size = input_vectors.shape[0]
         dimensions = input_vectors.shape[1]
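
When training_input_vectors is passed in memory, the block above first persists it as a TileDB array inside the index group (via write_input_vectors with the new array_name argument) and then treats it as a regular TILEDB_ARRAY training source. A hedged sketch of reading that array back after ingestion; the index URI is a placeholder and the {group.uri}/{array_name} layout follows write_input_vectors above:

import tiledb

index_uri = "s3://bucket/my_index"  # placeholder
training_uri = f"{index_uri}/training_input_vectors"
with tiledb.open(training_uri) as arr:  # read-only by default
    print(arr.schema)
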
@@ -1845,6 +1918,7 @@
             size=in_size,
             dimensions=dimensions,
             vector_type=vector_type,
+            array_name=INPUT_VECTORS_ARRAY_NAME
         )
         source_type = "TILEDB_ARRAY"
     else:
@@ -1871,6 +1945,7 @@
         workers = 1
     logger.debug("Partitions %d", partitions)
     logger.debug("Training sample size %d", training_sample_size)
+    logger.debug("Training source uri %s and type %s", training_source_uri, training_source_type)
     logger.debug("Number of workers %d", workers)
 
     if external_ids is not None:
@@ -1959,6 +2034,8 @@
         dimensions=dimensions,
         copy_centroids_uri=copy_centroids_uri,
         training_sample_size=training_sample_size,
+        training_source_uri=training_source_uri,
+        training_source_type=training_source_type,
         input_vectors_per_work_item=input_vectors_per_work_item,
         input_vectors_work_items_per_worker=input_vectors_work_items_per_worker,
         table_partitions_per_work_item=table_partitions_per_work_item,

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 1 addition & 0 deletions
@@ -465,6 +465,7 @@ def create(
         group_exists=group_exists,
         config=config,
     )
+    # TODO(paris): Save training_source_uri as metadata so that we use it for re-ingestion.
     with tiledb.scope_ctx(ctx_or_config=config):
         group = tiledb.Group(uri, "w")
         tile_size = int(TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions)

apis/python/src/tiledb/vector_search/storage_formats.py

Lines changed: 4 additions & 1 deletion
@@ -7,6 +7,7 @@
         "IDS_ARRAY_NAME": "ids.tdb",
         "PARTS_ARRAY_NAME": "parts.tdb",
         "INPUT_VECTORS_ARRAY_NAME": "input_vectors",
+        "TRAINING_INPUT_VECTORS_ARRAY_NAME": "training_input_vectors",
         "EXTERNAL_IDS_ARRAY_NAME": "external_ids",
         "PARTIAL_WRITE_ARRAY_DIR": "write_temp",
         "DEFAULT_ATTR_FILTERS": None,
@@ -19,6 +20,7 @@
         "IDS_ARRAY_NAME": "shuffled_vector_ids",
         "PARTS_ARRAY_NAME": "shuffled_vectors",
         "INPUT_VECTORS_ARRAY_NAME": "input_vectors",
+        "TRAINING_INPUT_VECTORS_ARRAY_NAME": "training_input_vectors",
         "EXTERNAL_IDS_ARRAY_NAME": "external_ids",
         "PARTIAL_WRITE_ARRAY_DIR": "temp_data",
         "DEFAULT_ATTR_FILTERS": tiledb.FilterList([tiledb.ZstdFilter()]),
@@ -31,6 +33,7 @@
         "IDS_ARRAY_NAME": "shuffled_vector_ids",
         "PARTS_ARRAY_NAME": "shuffled_vectors",
         "INPUT_VECTORS_ARRAY_NAME": "input_vectors",
+        "TRAINING_INPUT_VECTORS_ARRAY_NAME": "training_input_vectors",
         "EXTERNAL_IDS_ARRAY_NAME": "external_ids",
         "PARTIAL_WRITE_ARRAY_DIR": "temp_data",
         "DEFAULT_ATTR_FILTERS": tiledb.FilterList([tiledb.ZstdFilter()]),
@@ -44,4 +47,4 @@
 def validate_storage_version(storage_version):
     if storage_version not in storage_formats:
         valid_versions = ', '.join(storage_formats.keys())
-        raise ValueError(f"Invalid storage version: {storage_version}. Valid versions are: [{valid_versions}]")
+        raise ValueError(f"Invalid storage version: {storage_version} - valid versions are [{valid_versions}]")
