@@ -102,13 +102,14 @@ def test_index_with_incorrect_dimensions(tmp_path):
102102 # Wrong number of dimensions will raise a TypeError.
103103 with pytest .raises (TypeError ):
104104 index .query (np .array (1 , dtype = np .float32 ), k = 3 )
105+ with pytest .raises (TypeError ):
106+ index .query (np .array ([1 , 1 , 1 ], dtype = np .float32 ), k = 3 )
105107 with pytest .raises (TypeError ):
106108 index .query (np .array ([[[1 , 1 , 1 ]]], dtype = np .float32 ), k = 3 )
107109 with pytest .raises (TypeError ):
108110 index .query (np .array ([[[[1 , 1 , 1 ]]]], dtype = np .float32 ), k = 3 )
109111
110112 # Okay otherwise.
111- index .query (np .array ([1 , 1 , 1 ], dtype = np .float32 ), k = 3 )
112113 index .query (np .array ([[1 , 1 , 1 ]], dtype = np .float32 ), k = 3 )
113114
114115def test_index_with_incorrect_num_of_query_columns_simple (tmp_path ):
@@ -156,37 +157,3 @@ def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
156157 else :
157158 with pytest .raises (TypeError ):
158159 index .query (query , k = 1 )
159-
160- # TODO(paris): This will throw with the following error. Fix and re-enable, then remove
161- # test_index_with_incorrect_num_of_query_columns_in_single_vector_query:
162- # def array_to_matrix(array: np.ndarray):
163- # if array.dtype == np.float32:
164- # > return pyarray_copyto_matrix_f32(array)
165- # E RuntimeError: Number of dimensions must be two
166- # Here we test with a query which is just a vector, i.e. [1, 2, 3].
167- # query = query[0]
168- # if num_columns_for_query == num_columns:
169- # index.query(query, k=1)
170- # else:
171- # with pytest.raises(TypeError):
172- # index.query(query, k=1)
173-
174- def test_index_with_incorrect_num_of_query_columns_in_single_vector_query (tmp_path ):
175- # Tests that we raise a TypeError if the number of columns in the query is not the same as the
176- # number of columns in the indexed data, specifically for a single vector query.
177- # i.e. queries = [1, 2, 3] instead of queries = [[1, 2, 3], [4, 5, 6]].
178- indexes = [flat_index , ivf_flat_index ]
179- for index_type in indexes :
180- uri = os .path .join (tmp_path , f"array_{ index_type .__name__ } " )
181- index = index_type .create (uri = uri , dimensions = 3 , vector_type = np .dtype (np .uint8 ))
182-
183- # Wrong number of columns will raise a TypeError.
184- with pytest .raises (TypeError ):
185- index .query (np .array ([1 ], dtype = np .float32 ), k = 3 )
186- with pytest .raises (TypeError ):
187- index .query (np .array ([1 , 1 ], dtype = np .float32 ), k = 3 )
188- with pytest .raises (TypeError ):
189- index .query (np .array ([1 , 1 , 1 , 1 ], dtype = np .float32 ), k = 3 )
190-
191- # Okay otherwise.
192- index .query (np .array ([1 , 1 , 1 ], dtype = np .float32 ), k = 3 )
0 commit comments