@@ -886,15 +886,43 @@ def test_ingest_with_training_source_uri_f32(tmp_path):
886886 )
887887
888888def test_ingest_with_training_source_uri_tdb (tmp_path ):
889+ ################################################################################################
890+ # First set up the data.
891+ ################################################################################################
889892 dataset_dir = os .path .join (tmp_path , "dataset" )
890893 os .mkdir (dataset_dir )
891894 # data.shape should give you (cols, rows). So we transpose this before using it.
892- data = np .array ([[1.0 , 1.1 , 1.2 , 1.3 ], [2.0 , 2.1 , 2.2 , 2.3 ], [3.0 , 3.1 , 3.2 , 3.3 ], [4.0 , 4.1 , 4.2 , 4.3 ], [5.0 , 5.1 , 5.2 , 5.3 ]], dtype = np .float32 ).transpose ()
895+ data = np .array ([
896+ [1.0 , 1.1 , 1.2 , 1.3 ],
897+ [2.0 , 2.1 , 2.2 , 2.3 ],
898+ [3.0 , 3.1 , 3.2 , 3.3 ],
899+ [4.0 , 4.1 , 4.2 , 4.3 ],
900+ [5.0 , 5.1 , 5.2 , 5.3 ]], dtype = np .float32 ).transpose ()
893901 create_array (path = os .path .join (dataset_dir , "data.tdb" ), data = data )
894902
895- training_data = data [1 :3 ]
903+ training_data = np .array ([
904+ [1.0 , 1.1 , 1.2 , 1.3 ],
905+ [5.0 , 5.1 , 5.2 , 5.3 ]], dtype = np .float32 ).transpose ()
896906 create_array (path = os .path .join (dataset_dir , "training_data.tdb" ), data = training_data )
897907
908+ # Run a quick test that if we set up training_data incorrectly, we will raise an exception.
909+ with pytest .raises (ValueError ) as error :
910+ training_data_invalid = np .array ([
911+ [1.0 , 1.1 , 1.2 ],
912+ [5.0 , 5.1 , 5.2 ]], dtype = np .float32 ).transpose ()
913+ create_array (path = os .path .join (dataset_dir , "training_data_invalid.tdb" ), data = training_data_invalid )
914+ index = ingest (
915+ index_type = "IVF_FLAT" ,
916+ index_uri = os .path .join (tmp_path , f"array_invalid" ),
917+ source_uri = os .path .join (dataset_dir , "data.tdb" ),
918+ training_source_uri = os .path .join (dataset_dir , "training_data_invalid.tdb" )
919+ )
920+ assert "training data dimensions" in str (error .value )
921+
922+ ################################################################################################
923+ # Test we can ingest, query, update, and consolidate with a training_source_uri.
924+ ################################################################################################
925+ print ('[test_ingestion@test_ingest_with_training_source_uri_tdb] ingest() ======================================' )
898926 index_uri = os .path .join (tmp_path , "array" )
899927 index = ingest (
900928 index_type = "IVF_FLAT" ,
@@ -922,9 +950,33 @@ def test_ingest_with_training_source_uri_tdb(tmp_path):
922950 )
923951
924952def test_ingest_with_training_source_uri_numpy (tmp_path ):
925- data = np .array ([[1.0 , 1.1 , 1.2 , 1.3 ], [2.0 , 2.1 , 2.2 , 2.3 ], [3.0 , 3.1 , 3.2 , 3.3 ], [4.0 , 4.1 , 4.2 , 4.3 ], [5.0 , 5.1 , 5.2 , 5.3 ]], dtype = np .float32 )
953+ ################################################################################################
954+ # First set up the data.
955+ ################################################################################################
956+ data = np .array ([
957+ [1.0 , 1.1 , 1.2 , 1.3 ],
958+ [2.0 , 2.1 , 2.2 , 2.3 ],
959+ [3.0 , 3.1 , 3.2 , 3.3 ],
960+ [4.0 , 4.1 , 4.2 , 4.3 ],
961+ [5.0 , 5.1 , 5.2 , 5.3 ]], dtype = np .float32 )
926962 training_data = data [1 :3 ]
927963
964+ # Run a quick test that if we set up training_data incorrectly, we will raise an exception.
965+ with pytest .raises (ValueError ) as error :
966+ training_data_invalid = np .array ([
967+ [4.0 , 4.1 , 4.2 ],
968+ [5.0 , 5.1 , 5.2 ]], dtype = np .float32 )
969+ index = ingest (
970+ index_type = "IVF_FLAT" ,
971+ index_uri = os .path .join (tmp_path , "array_invalid" ),
972+ input_vectors = data ,
973+ training_input_vectors = training_data_invalid ,
974+ )
975+ assert "training data dimensions" in str (error .value )
976+
977+ ################################################################################################
978+ # Test we can ingest, query, update, and consolidate.
979+ ################################################################################################
928980 index_uri = os .path .join (tmp_path , "array" )
929981 index = ingest (
930982 index_type = "IVF_FLAT" ,
0 commit comments