
Commit 9858c62

Add support for randomly sampling the input data to compute centroids with (#183)
1 parent e47bd83 commit 9858c62

4 files changed: +341 -7 lines

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 218 additions & 4 deletions
@@ -1,14 +1,22 @@
 import json
 from functools import partial
 from typing import Any, Mapping, Optional, Tuple
-
+import enum
+import random
 import numpy as np
-from tiledb.cloud.dag import Mode
 
+from tiledb.cloud.dag import Mode
 from tiledb.vector_search._tiledbvspy import *
 from tiledb.vector_search.storage_formats import STORAGE_VERSION, validate_storage_version
 
 
+class TrainingSamplingPolicy(enum.Enum):
+    FIRST_N = 1
+    RANDOM = 2
+
+    def __str__(self):
+        return self.name.replace("_", " ").title()
+
 def ingest(
     index_type: str,
     index_uri: str,
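Note on the new enum: its `__str__` turns the member name into a human-readable label. A minimal standalone sketch (not part of the diff) of how it behaves:

```python
import enum

class TrainingSamplingPolicy(enum.Enum):
    FIRST_N = 1
    RANDOM = 2

    def __str__(self):
        # "FIRST_N" -> "First N", "RANDOM" -> "Random"
        return self.name.replace("_", " ").title()

print(TrainingSamplingPolicy.FIRST_N)  # First N
print(TrainingSamplingPolicy.RANDOM)   # Random
```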
@@ -25,6 +33,7 @@ def ingest(
     namespace: Optional[str] = None,
     size: int = -1,
     partitions: int = -1,
+    training_sampling_policy: TrainingSamplingPolicy = TrainingSamplingPolicy.FIRST_N,
     copy_centroids_uri: str = None,
     training_sample_size: int = -1,
     training_input_vectors: np.ndarray = None,
@@ -33,6 +42,8 @@ def ingest(
     workers: int = -1,
     input_vectors_per_work_item: int = -1,
     max_tasks_per_stage: int = -1,
+    input_vectors_per_work_item_during_sampling: int = -1,
+    max_sampling_tasks: int = -1,
     storage_version: str = STORAGE_VERSION,
     verbose: bool = False,
     trace_id: Optional[str] = None,
@@ -102,6 +113,14 @@ def ingest(
     max_tasks_per_stage: int = -1
         Max number of tasks per execution stage of ingestion,
         if not provided, is auto-configured
+    input_vectors_per_work_item_during_sampling: int = -1
+        Number of vectors per sample ingestion work item,
+        if not provided, is auto-configured.
+        Only valid with training_sampling_policy=TrainingSamplingPolicy.RANDOM
+    max_sampling_tasks: int = -1
+        Max number of tasks per execution stage of sampling,
+        if not provided, is auto-configured.
+        Only valid with training_sampling_policy=TrainingSamplingPolicy.RANDOM
     storage_version: str
         Vector index storage format version. If not provided, defaults to the latest version.
     verbose: bool
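To make the new parameters concrete, a hedged usage sketch (the URIs, sizes, and the IVF_FLAT/FVEC choices below are hypothetical; only the parameter names come from this change):

```python
from tiledb.vector_search.ingestion import ingest, TrainingSamplingPolicy

index = ingest(
    index_type="IVF_FLAT",
    index_uri="s3://my-bucket/my-index",       # hypothetical URI
    source_uri="s3://my-bucket/vectors.fvec",  # hypothetical URI
    source_type="FVEC",
    training_sample_size=1_000_000,
    training_sampling_policy=TrainingSamplingPolicy.RANDOM,
    # Optional tuning for the sampling stage; auto-configured when left at -1.
    input_vectors_per_work_item_during_sampling=500_000,
    max_sampling_tasks=20,
)
```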
@@ -169,6 +188,10 @@ def ingest(
     for variable in ["copy_centroids_uri", "training_input_vectors", "training_source_uri", "training_source_type"]:
         if index_type != "IVF_FLAT" and locals().get(variable) is not None:
             raise ValueError(f"{variable} should only be provided with index_type IVF_FLAT")
+
+    for variable in ["copy_centroids_uri", "training_input_vectors", "training_source_uri", "training_source_type"]:
+        if training_sampling_policy != TrainingSamplingPolicy.FIRST_N and locals().get(variable) is not None:
+            raise ValueError(f"{variable} should not be provided alongside training_sampling_policy")
 
     # use index_group_uri for internal clarity
     index_group_uri = index_uri
@@ -191,6 +214,7 @@ def ingest(
     ]
     DEFAULT_ATTR_FILTERS = storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"]
     VECTORS_PER_WORK_ITEM = 20000000
+    VECTORS_PER_SAMPLE_WORK_ITEM = 1000000
     MAX_TASKS_PER_STAGE = 100
     CENTRALISED_KMEANS_MAX_SAMPLE_SIZE = 1000000
     DEFAULT_IMG_NAME = "3.9-vectorsearch"
@@ -302,9 +326,8 @@ def read_source_metadata(
         else:
             raise ValueError(f"Not supported source_type {source_type} - valid types are [TILEDB_ARRAY, U8BIN, F32BIN, FVEC, IVEC, BVEC]")
 
-    def write_input_vectors(
+    def create_array(
         group: tiledb.Group,
-        input_vectors: np.ndarray,
         size: int,
         dimensions: int,
         vector_type: np.dtype,
@@ -350,6 +373,18 @@ def write_input_vectors(
         tiledb.Array.create(input_vectors_array_uri, input_vectors_array_schema)
         group.add(input_vectors_array_uri, name=array_name)
 
+        return input_vectors_array_uri
+
+    def write_input_vectors(
+        group: tiledb.Group,
+        input_vectors: np.ndarray,
+        size: int,
+        dimensions: int,
+        vector_type: np.dtype,
+        array_name: str
+    ) -> str:
+        input_vectors_array_uri = create_array(group=group, size=size, dimensions=dimensions, vector_type=vector_type, array_name=array_name)
+
         input_vectors_array = tiledb.open(
             input_vectors_array_uri, "w", timestamp=index_timestamp
         )
@@ -797,6 +832,91 @@ def copy_centroids(
     # --------------------------------------------------------------------
     # centralised kmeans UDFs
     # --------------------------------------------------------------------
+    def random_sample_from_input_vectors(
+        source_uri: str,
+        source_type: str,
+        vector_type: np.dtype,
+        dimensions: int,
+        source_start_pos: int,
+        source_end_pos: int,
+        batch: int,
+        random_sample_size: int,
+        output_source_uri: str,
+        output_start_pos: int,
+        config: Optional[Mapping[str, Any]] = None,
+        verbose: bool = False,
+    ):
+        '''
+        Reads a random sample of vectors from the source data and appends them to the output array.
+
+        Parameters
+        ----------
+        source_uri: str
+            Data source URI.
+        source_type: str
+            Type of the source data.
+        vector_type: np.dtype
+            Type of the vectors.
+        dimensions: int
+            Number of dimensions in a vector.
+        source_start_pos: int
+            Start position of source_uri to read from.
+        source_end_pos: int
+            End position of source_uri to read to.
+        batch: int
+            Read the source data in batches of this size.
+        random_sample_size: int
+            Number of vectors to randomly sample from the source data.
+        output_source_uri: str
+            URI of the output array.
+        output_start_pos: int
+            Start position of the output array to write to.
+        '''
+        if random_sample_size == 0:
+            return
+
+        with tiledb.scope_ctx(ctx_or_config=config):
+            source_size = source_end_pos - source_start_pos
+            num_sampled = 0
+            for start in range(source_start_pos, source_end_pos, batch):
+                # What vectors to read from the source_uri.
+                end = start + batch
+                if end > source_end_pos:
+                    end = source_end_pos
+
+                # How many vectors to sample from the vectors read.
+                percent_of_data_to_read = (end - start) / source_size
+                num_to_sample = math.ceil(random_sample_size * percent_of_data_to_read)
+                if num_sampled + num_to_sample > random_sample_size:
+                    num_to_sample = random_sample_size - num_sampled
+                if num_to_sample == 0:
+                    continue
+                num_sampled += num_to_sample
+
+                # Read from the source data.
+                vectors = read_input_vectors(
+                    source_uri=source_uri,
+                    source_type=source_type,
+                    vector_type=vector_type,
+                    dimensions=dimensions,
+                    start_pos=start,
+                    end_pos=end,
+                    config=config,
+                    verbose=verbose,
+                    trace_id=trace_id,
+                )
+
+                # Randomly sample from the data we read.
+                row_indices = np.random.choice(vectors.shape[0], size=num_to_sample, replace=False)
+                sampled_vectors = vectors[row_indices]
+
+                # Append to output array.
+                with tiledb.open(output_source_uri, mode="w", timestamp=index_timestamp) as A:
+                    A[0:dimensions, output_start_pos:output_start_pos + num_to_sample] = np.transpose(sampled_vectors)
+
+        if num_sampled != random_sample_size:
+            raise ValueError(f"The random sampling within a batch ran into an issue: num_sampled ({num_sampled}) != random_sample_size ({random_sample_size})")
+
     def centralised_kmeans(
         index_group_uri: str,
         source_uri: str,
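The core of the new UDF is plain NumPy: pick `num_to_sample` distinct rows from the batch that was read, then write them transposed because the training array is laid out as (dimensions, positions). A standalone sketch of just that step, with made-up shapes:

```python
import numpy as np

dimensions = 4
num_to_sample = 3

# Stand-in for one batch returned by read_input_vectors: one vector per row.
vectors = np.random.rand(10, dimensions).astype(np.float32)

# Sample rows without replacement, as the UDF does.
row_indices = np.random.choice(vectors.shape[0], size=num_to_sample, replace=False)
sampled_vectors = vectors[row_indices]      # shape (3, 4)

# The output array stores vectors column-wise, hence the transpose before writing.
to_write = np.transpose(sampled_vectors)    # shape (4, 3)
assert to_write.shape == (dimensions, num_to_sample)
```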
@@ -1564,6 +1684,8 @@ def create_ingestion_dag(
         training_source_type: Optional[str],
         input_vectors_per_work_item: int,
         input_vectors_work_items_per_worker: int,
+        input_vectors_per_work_item_during_sampling: int,
+        input_vectors_work_items_per_worker_during_sampling: int,
         table_partitions_per_work_item: int,
         table_partitions_work_items_per_worker: int,
         workers: int,
@@ -1600,6 +1722,15 @@ def create_ingestion_dag(
         input_vectors_batch_size = (
             input_vectors_per_work_item * input_vectors_work_items_per_worker
        )
+
+        # The number of vectors each task will read.
+        input_vectors_batch_size_during_sampling = (
+            # The number of vectors to read into memory in one batch within a task.
+            input_vectors_per_work_item_during_sampling *
+            # The number of batches that a single task will need to run.
+            input_vectors_work_items_per_worker_during_sampling
+        )
+
         if index_type == "FLAT":
             ingest_node = submit(
                 ingest_flat,
@@ -1637,6 +1768,59 @@ def create_ingestion_dag(
                 image_name=DEFAULT_IMG_NAME,
             )
         else:
+            random_sample_nodes = []
+            if training_sampling_policy == TrainingSamplingPolicy.RANDOM:
+                # Create an empty array to write the sampled vectors to.
+                group = tiledb.Group(index_group_uri, "w")
+                training_source_uri = create_array(
+                    group=group,
+                    size=training_sample_size,
+                    dimensions=dimensions,
+                    vector_type=vector_type,
+                    array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME
+                )
+                training_source_type = "TILEDB_ARRAY"
+                group.close()
+
+                idx = 0
+                num_sampled = 0
+                for start in range(0, in_size, input_vectors_batch_size_during_sampling):
+                    # What vectors to read from the source_uri.
+                    end = start + input_vectors_batch_size_during_sampling
+                    if end > size:
+                        end = size
+
+                    # How many vectors to sample from the vectors read.
+                    percent_of_data_to_read = (end - start) / in_size
+                    num_to_sample = math.ceil(training_sample_size * percent_of_data_to_read)
+                    if num_sampled + num_to_sample > training_sample_size:
+                        num_to_sample = training_sample_size - num_sampled
+                    if num_to_sample == 0:
+                        continue
+
+                    random_sample_nodes.append(submit(
+                        random_sample_from_input_vectors,
+                        source_uri=source_uri,
+                        source_type=source_type,
+                        vector_type=vector_type,
+                        dimensions=dimensions,
+                        source_start_pos=start,
+                        source_end_pos=end,
+                        batch=input_vectors_per_work_item_during_sampling,
+                        random_sample_size=num_to_sample,
+                        output_source_uri=training_source_uri,
+                        output_start_pos=num_sampled,
+                        config=config,
+                        verbose=verbose,
+                        name="read-random-sample-" + str(idx),
+                        resources={"cpu": str(threads), "memory": "1Gi"},
+                        image_name=DEFAULT_IMG_NAME,
+                    ))
+                    num_sampled += num_to_sample
+                    idx += 1
+                if num_sampled != training_sample_size:
+                    raise ValueError(f"The random sampling ran into an issue: num_sampled ({num_sampled}) != training_sample_size ({training_sample_size})")
+
             if training_sample_size <= CENTRALISED_KMEANS_MAX_SAMPLE_SIZE:
                 centroids_node = submit(
                     centralised_kmeans,
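The driver loop above gives each sampling task a slice of the input, a proportional share of the sample, and a disjoint offset into the output array. A small standalone sketch of that bookkeeping (all sizes hypothetical):

```python
import math

in_size = 2_500_000           # hypothetical input size
batch_size = 1_000_000        # stands in for input_vectors_batch_size_during_sampling
training_sample_size = 100_000

num_sampled = 0
tasks = []
for start in range(0, in_size, batch_size):
    end = min(start + batch_size, in_size)
    num_to_sample = math.ceil(training_sample_size * (end - start) / in_size)
    num_to_sample = min(num_to_sample, training_sample_size - num_sampled)
    if num_to_sample == 0:
        continue
    # (source_start_pos, source_end_pos, random_sample_size, output_start_pos)
    tasks.append((start, end, num_to_sample, num_sampled))
    num_sampled += num_to_sample

assert num_sampled == training_sample_size
print(tasks)
# [(0, 1000000, 40000, 0), (1000000, 2000000, 40000, 40000), (2000000, 2500000, 20000, 80000)]
```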
@@ -1657,6 +1841,9 @@ def create_ingestion_dag(
                     resources={"cpu": "8", "memory": "32Gi"},
                     image_name=DEFAULT_IMG_NAME,
                 )
+
+                for random_sample_node in random_sample_nodes:
+                    centroids_node.depends_on(random_sample_node)
             else:
                 internal_centroids_node = submit(
                     init_centroids,
@@ -1950,6 +2137,9 @@ def consolidate_and_vacuum(
         logger.debug("Input dataset size %d", size)
         logger.debug("Input dataset dimensions %d", dimensions)
         logger.debug("Vector dimension type %s", vector_type)
+        if training_sample_size > in_size:
+            raise ValueError(f"training_sample_size {training_sample_size} is larger than the input dataset size {in_size}")
+
         if partitions == -1:
             partitions = max(1, int(math.sqrt(size)))
         if training_sample_size == -1:
@@ -1976,6 +2166,7 @@ def consolidate_and_vacuum(
         if external_ids_type is None:
             external_ids_type = "U64BIN"
 
+        # Compute task parameters for main ingestion.
         if input_vectors_per_work_item == -1:
             input_vectors_per_work_item = VECTORS_PER_WORK_ITEM
         input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
@@ -1996,6 +2187,27 @@ def consolidate_and_vacuum(
             input_vectors_work_items_per_worker,
         )
 
+        # Compute task parameters for random sampling.
+        # How many input vectors to read into memory in one batch within a task.
+        if input_vectors_per_work_item_during_sampling == -1:
+            input_vectors_per_work_item_during_sampling = VECTORS_PER_SAMPLE_WORK_ITEM
+        # How many total batches we need to read all the data.
+        input_vectors_work_items_during_sampling = int(math.ceil(size / input_vectors_per_work_item_during_sampling))
+        # The number of tasks to create, at max.
+        if max_sampling_tasks == -1:
+            max_sampling_tasks = MAX_TASKS_PER_STAGE
+        # The number of batches a single task will run. If there are more batches required than
+        # allowed tasks, each task will process multiple batches.
+        input_vectors_work_items_per_worker_during_sampling = 1
+        if input_vectors_work_items_during_sampling > max_sampling_tasks:
+            input_vectors_work_items_per_worker_during_sampling = int(
+                math.ceil(input_vectors_work_items_during_sampling / max_sampling_tasks)
+            )
+            input_vectors_work_items_during_sampling = max_sampling_tasks
+        logger.debug("input_vectors_per_work_item_during_sampling %d", input_vectors_per_work_item_during_sampling)
+        logger.debug("input_vectors_work_items_during_sampling %d", input_vectors_work_items_during_sampling)
+        logger.debug("input_vectors_work_items_per_worker_during_sampling %d", input_vectors_work_items_per_worker_during_sampling)
+
         vectors_per_table_partitions = max(1, size / partitions)
         table_partitions_per_work_item = max(
             1,
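For a concrete feel of the auto-configuration above, here is the same arithmetic run standalone, using the defaults added in this commit and a hypothetical dataset size:

```python
import math

VECTORS_PER_SAMPLE_WORK_ITEM = 1000000   # default vectors read per batch
MAX_TASKS_PER_STAGE = 100                # default cap on sampling tasks

size = 250_000_000                       # hypothetical dataset size
per_work_item = VECTORS_PER_SAMPLE_WORK_ITEM
work_items = int(math.ceil(size / per_work_item))   # 250 batches to cover the data
max_tasks = MAX_TASKS_PER_STAGE

work_items_per_worker = 1
if work_items > max_tasks:
    # More batches than allowed tasks: each task runs several batches.
    work_items_per_worker = int(math.ceil(work_items / max_tasks))  # ceil(250 / 100) = 3
    work_items = max_tasks                                          # capped at 100 tasks

print(work_items, work_items_per_worker)  # 100 3
```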
@@ -2054,6 +2266,8 @@ def consolidate_and_vacuum(
             training_source_type=training_source_type,
             input_vectors_per_work_item=input_vectors_per_work_item,
             input_vectors_work_items_per_worker=input_vectors_work_items_per_worker,
+            input_vectors_per_work_item_during_sampling=input_vectors_per_work_item_during_sampling,
+            input_vectors_work_items_per_worker_during_sampling=input_vectors_work_items_per_worker_during_sampling,
             table_partitions_per_work_item=table_partitions_per_work_item,
             table_partitions_work_items_per_worker=table_partitions_work_items_per_worker,
             workers=workers,

apis/python/test/common.py

Lines changed: 9 additions & 1 deletion
@@ -5,7 +5,7 @@
 import numpy as np
 
 import tiledb
-
+from tiledb.vector_search.storage_formats import storage_formats, STORAGE_VERSION
 
 def xbin_mmap(fname, dtype):
     n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
@@ -302,3 +302,11 @@ def check_equals(result_d, result_i, expected_result_d, expected_result_i):
 def random_name(name: str) -> str:
     suffix = "".join(random.choices(string.ascii_letters, k=10))
     return f"zzz_unittest_{name}_{suffix}"
+
+def check_training_input_vectors(index_uri: str, expected_training_sample_size: int, expected_dimensions: int):
+    training_input_vectors_uri = f"{index_uri}/{storage_formats[STORAGE_VERSION]['TRAINING_INPUT_VECTORS_ARRAY_NAME']}"
+    with tiledb.open(training_input_vectors_uri, mode="r") as src_array:
+        training_input_vectors = np.transpose(src_array[:, :]["values"])
+        assert training_input_vectors.shape[0] == expected_training_sample_size
+        assert training_input_vectors.shape[1] == expected_dimensions
+        assert not np.isnan(training_input_vectors).any()
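A sketch of how the new helper might be exercised in a test, assuming `ingest` accepts an in-memory `input_vectors` array and that the helper is imported from this test module; all values are illustrative:

```python
import numpy as np

from tiledb.vector_search.ingestion import ingest, TrainingSamplingPolicy
from common import check_training_input_vectors

dimensions = 128
input_vectors = np.random.rand(10_000, dimensions).astype(np.float32)
index_uri = "/tmp/test_random_sampling_index"  # hypothetical path

ingest(
    index_type="IVF_FLAT",
    index_uri=index_uri,
    input_vectors=input_vectors,
    training_sample_size=1_000,
    training_sampling_policy=TrainingSamplingPolicy.RANDOM,
)
check_training_input_vectors(
    index_uri=index_uri,
    expected_training_sample_size=1_000,
    expected_dimensions=dimensions,
)
```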
