Skip to content

Commit f5734f9

Browse files
authored
Fix training_data bug in test_ingest_with_training_source_uri_tdb and validate training data dimensions in ingest() (#175)
1 parent bb20254 commit f5734f9

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,13 @@ def ingest(
157157
raise ValueError("training_source_type should not be provided without training_source_uri")
158158

159159
if training_sample_size < -1:
160-
raise ValueError("training_sample_size should either be positive or -1 to auto-configure based on the dataset sizes")
160+
raise ValueError("training_sample_size should either be positive or -1 (to auto-configure based on the dataset sizes)")
161+
162+
if index_type != "IVF_FLAT" and training_sample_size != -1:
163+
raise ValueError("training_sample_size should only be provided with index_type IVF_FLAT")
164+
for variable in ["copy_centroids_uri", "training_input_vectors", "training_source_uri", "training_source_type"]:
165+
if index_type != "IVF_FLAT" and locals().get(variable) is not None:
166+
raise ValueError(f"{variable} should only be provided with index_type IVF_FLAT")
161167

162168
# use index_group_uri for internal clarity
163169
index_group_uri = index_uri
@@ -820,12 +826,13 @@ def centralised_kmeans(
820826
if training_source_type is None:
821827
training_source_type = autodetect_source_type(source_uri=training_source_uri)
822828
training_in_size, training_dimensions, training_vector_type = read_source_metadata(source_uri=training_source_uri, source_type=training_source_type)
823-
dimensions = training_dimensions
829+
if dimensions != training_dimensions:
830+
raise ValueError(f"When training centroids, the index data dimensions ({dimensions}) != the training data dimensions ({training_dimensions})")
824831
sample_vectors = read_input_vectors(
825832
source_uri=training_source_uri,
826833
source_type=training_source_type,
827834
vector_type=training_vector_type,
828-
dimensions=training_dimensions,
835+
dimensions=dimensions,
829836
start_pos=0,
830837
end_pos=training_in_size,
831838
config=config,
@@ -2076,4 +2083,4 @@ def consolidate_and_vacuum(
20762083
elif index_type == "IVF_FLAT":
20772084
return ivf_flat_index.IVFFlatIndex(
20782085
uri=index_group_uri, memory_budget=1000000, config=config
2079-
)
2086+
)

apis/python/test/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,8 @@ def check_equals(result_d, result_i, expected_result_d, expected_result_i):
295295
result_i_expected: int
296296
The expected indices
297297
"""
298-
assert result_i == expected_result_i
299-
assert result_d == expected_result_d
298+
assert result_i == expected_result_i, f"result_i: {result_i} != expected_result_i: {expected_result_i}"
299+
assert result_d == expected_result_d, f"result_d: {result_d} != expected_result_d: {expected_result_d}"
300300

301301
# Generate random names for test array uris
302302
def random_name(name: str) -> str:

apis/python/test/test_ingestion.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -886,15 +886,43 @@ def test_ingest_with_training_source_uri_f32(tmp_path):
886886
)
887887

888888
def test_ingest_with_training_source_uri_tdb(tmp_path):
889+
################################################################################################
890+
# First set up the data.
891+
################################################################################################
889892
dataset_dir = os.path.join(tmp_path, "dataset")
890893
os.mkdir(dataset_dir)
891894
# data.shape should give you (cols, rows). So we transpose this before using it.
892-
data = np.array([[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3], [3.0, 3.1, 3.2, 3.3], [4.0, 4.1, 4.2, 4.3], [5.0, 5.1, 5.2, 5.3]], dtype=np.float32).transpose()
895+
data = np.array([
896+
[1.0, 1.1, 1.2, 1.3],
897+
[2.0, 2.1, 2.2, 2.3],
898+
[3.0, 3.1, 3.2, 3.3],
899+
[4.0, 4.1, 4.2, 4.3],
900+
[5.0, 5.1, 5.2, 5.3]], dtype=np.float32).transpose()
893901
create_array(path=os.path.join(dataset_dir, "data.tdb"), data=data)
894902

895-
training_data = data[1:3]
903+
training_data = np.array([
904+
[1.0, 1.1, 1.2, 1.3],
905+
[5.0, 5.1, 5.2, 5.3]], dtype=np.float32).transpose()
896906
create_array(path=os.path.join(dataset_dir, "training_data.tdb"), data=training_data)
897907

908+
# Run a quick test that if we set up training_data incorrectly, we will raise an exception.
909+
with pytest.raises(ValueError) as error:
910+
training_data_invalid = np.array([
911+
[1.0, 1.1, 1.2],
912+
[5.0, 5.1, 5.2]], dtype=np.float32).transpose()
913+
create_array(path=os.path.join(dataset_dir, "training_data_invalid.tdb"), data=training_data_invalid)
914+
index = ingest(
915+
index_type="IVF_FLAT",
916+
index_uri=os.path.join(tmp_path, f"array_invalid"),
917+
source_uri=os.path.join(dataset_dir, "data.tdb"),
918+
training_source_uri=os.path.join(dataset_dir, "training_data_invalid.tdb")
919+
)
920+
assert "training data dimensions" in str(error.value)
921+
922+
################################################################################################
923+
# Test we can ingest, query, update, and consolidate with a training_source_uri.
924+
################################################################################################
925+
print('[test_ingestion@test_ingest_with_training_source_uri_tdb] ingest() ======================================')
898926
index_uri = os.path.join(tmp_path, "array")
899927
index = ingest(
900928
index_type="IVF_FLAT",
@@ -922,9 +950,33 @@ def test_ingest_with_training_source_uri_tdb(tmp_path):
922950
)
923951

924952
def test_ingest_with_training_source_uri_numpy(tmp_path):
925-
data = np.array([[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3], [3.0, 3.1, 3.2, 3.3], [4.0, 4.1, 4.2, 4.3], [5.0, 5.1, 5.2, 5.3]], dtype=np.float32)
953+
################################################################################################
954+
# First set up the data.
955+
################################################################################################
956+
data = np.array([
957+
[1.0, 1.1, 1.2, 1.3],
958+
[2.0, 2.1, 2.2, 2.3],
959+
[3.0, 3.1, 3.2, 3.3],
960+
[4.0, 4.1, 4.2, 4.3],
961+
[5.0, 5.1, 5.2, 5.3]], dtype=np.float32)
926962
training_data = data[1:3]
927963

964+
# Run a quick test that if we set up training_data incorrectly, we will raise an exception.
965+
with pytest.raises(ValueError) as error:
966+
training_data_invalid = np.array([
967+
[4.0, 4.1, 4.2],
968+
[5.0, 5.1, 5.2]], dtype=np.float32)
969+
index = ingest(
970+
index_type="IVF_FLAT",
971+
index_uri=os.path.join(tmp_path, "array_invalid"),
972+
input_vectors=data,
973+
training_input_vectors=training_data_invalid,
974+
)
975+
assert "training data dimensions" in str(error.value)
976+
977+
################################################################################################
978+
# Test we can ingest, query, update, and consolidate.
979+
################################################################################################
928980
index_uri = os.path.join(tmp_path, "array")
929981
index = ingest(
930982
index_type="IVF_FLAT",

0 commit comments

Comments (0)