|
12 | 12 | import org.apache.lucene.search.BooleanClause; |
13 | 13 | import org.apache.lucene.search.BooleanQuery; |
14 | 14 | import org.apache.lucene.search.MatchNoDocsQuery; |
| 15 | +import org.apache.lucene.search.PatienceKnnVectorQuery; |
15 | 16 | import org.apache.lucene.search.Query; |
16 | 17 | import org.apache.lucene.search.knn.KnnSearchStrategy; |
17 | 18 | import org.elasticsearch.TransportVersion; |
@@ -206,11 +207,21 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que |
206 | 207 | switch (elementType()) { |
207 | 208 | case FLOAT -> assertThat( |
208 | 209 | query, |
209 | | - anyOf(instanceOf(ESKnnFloatVectorQuery.class), instanceOf(DenseVectorQuery.Floats.class), instanceOf(BooleanQuery.class)) |
| 210 | + anyOf( |
| 211 | + instanceOf(ESKnnFloatVectorQuery.class), |
| 212 | + instanceOf(DenseVectorQuery.Floats.class), |
| 213 | + instanceOf(BooleanQuery.class), |
| 214 | + instanceOf(PatienceKnnVectorQuery.class) |
| 215 | + ) |
210 | 216 | ); |
211 | 217 | case BYTE -> assertThat( |
212 | 218 | query, |
213 | | - anyOf(instanceOf(ESKnnByteVectorQuery.class), instanceOf(DenseVectorQuery.Bytes.class), instanceOf(BooleanQuery.class)) |
| 219 | + anyOf( |
| 220 | + instanceOf(ESKnnByteVectorQuery.class), |
| 221 | + instanceOf(DenseVectorQuery.Bytes.class), |
| 222 | + instanceOf(BooleanQuery.class), |
| 223 | + instanceOf(PatienceKnnVectorQuery.class) |
| 224 | + ) |
214 | 225 | ); |
215 | 226 | } |
216 | 227 |
|
@@ -278,7 +289,18 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que |
278 | 289 | if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { |
279 | 290 | query = vectorSimilarityQuery.getInnerKnnQuery(); |
280 | 291 | } |
281 | | - assertThat(query, anyOf(equalTo(knnVectorQueryBuilt), equalTo(bruteForceVectorQueryBuilt))); |
| 292 | + assertThat( |
| 293 | + query, |
| 294 | + anyOf( |
| 295 | + equalTo(knnVectorQueryBuilt), |
| 296 | + equalTo( |
| 297 | + knnVectorQueryBuilt instanceof ESKnnByteVectorQuery esKnnByteVectorQuery |
| 298 | + ? PatienceKnnVectorQuery.fromByteQuery(esKnnByteVectorQuery) |
| 299 | + : PatienceKnnVectorQuery.fromFloatQuery((ESKnnFloatVectorQuery) knnVectorQueryBuilt) |
| 300 | + ), |
| 301 | + equalTo(bruteForceVectorQueryBuilt) |
| 302 | + ) |
| 303 | + ); |
282 | 304 | } |
283 | 305 |
|
284 | 306 | public void testWrongDimension() { |
|
0 commit comments