@@ -20,17 +20,23 @@
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.ScoreMode;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.compress.CompressedXContent;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
@@ -46,6 +52,7 @@
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
@@ -55,6 +62,7 @@
 import org.elasticsearch.search.vectors.SparseVectorQueryWrapper;
 import org.elasticsearch.test.AbstractQueryTestCase;
 import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.TransportVersionUtils;
 import org.elasticsearch.test.client.NoOpClient;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -86,6 +94,7 @@
 import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
 import static org.apache.lucene.search.BooleanClause.Occur.MUST;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.notNullValue;
@@ -361,6 +370,97 @@ public void testIllegalValues() { |
         }
     }
 
+    public void testSerializationBwc() throws IOException {
+        InferenceResults inferenceResults1 = new TextExpansionResults(
+            DEFAULT_RESULTS_FIELD,
+            List.of(new WeightedToken("foo", 1.0f)),
+            false
+        );
+        InferenceResults inferenceResults2 = new TextExpansionResults(
+            DEFAULT_RESULTS_FIELD,
+            List.of(new WeightedToken("bar", 2.0f)),
+            false
+        );
+
+        // Single inference result
+        CheckedBiConsumer<InferenceResults, TransportVersion, IOException> assertSingleInferenceResult = (inferenceResults, version) -> {
+            String fieldName = randomAlphaOfLength(5);
+            String query = randomAlphaOfLength(5);
+
+            MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
+            mapInferenceResultsProvider.addInferenceResults(randomAlphaOfLength(5), inferenceResults);
+            SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(fieldName, query, null, mapInferenceResultsProvider);
+
+            SingleInferenceResultsProvider singleInferenceResultsProvider = new SingleInferenceResultsProvider(inferenceResults);
+            SemanticQueryBuilder bwcQuery = new SemanticQueryBuilder(fieldName, query, null, singleInferenceResultsProvider);
+
+            try (BytesStreamOutput output = new BytesStreamOutput()) {
+                output.setTransportVersion(version);
+                output.writeNamedWriteable(originalQuery);
+                try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
+                    in.setTransportVersion(version);
+                    QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
+
+                    SemanticQueryBuilder expectedQuery = version.onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)
+                        ? originalQuery
+                        : bwcQuery;
+                    assertThat(deserializedQuery, equalTo(expectedQuery));
+                }
+            }
+        };
+
+        for (int i = 0; i < 100; i++) {
+            TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
+                random(),
+                TransportVersions.V_8_15_0,
+                TransportVersion.current()
+            );
+            assertSingleInferenceResult.accept(inferenceResults1, transportVersion);
+        }
+
+        // Multiple inference results
+        CheckedBiConsumer<List<InferenceResults>, TransportVersion, IOException> assertMultipleInferenceResults = (
+            inferenceResultsList,
+            version) -> {
+            MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
+            inferenceResultsList.forEach(result -> mapInferenceResultsProvider.addInferenceResults(randomAlphaOfLength(5), result));
+            SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(
+                randomAlphaOfLength(5),
+                randomAlphaOfLength(5),
+                null,
+                mapInferenceResultsProvider
+            );
+
+            try (BytesStreamOutput output = new BytesStreamOutput()) {
+                output.setTransportVersion(version);
+
+                if (version.onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) {
+                    output.writeNamedWriteable(originalQuery);
+                    try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
+                        in.setTransportVersion(version);
+                        QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
+                        assertThat(deserializedQuery, equalTo(originalQuery));
+                    }
+                } else {
+                    IllegalArgumentException e = assertThrows(
+                        IllegalArgumentException.class,
+                        () -> output.writeNamedWriteable(originalQuery)
+                    );
+                    assertThat(e.getMessage(), containsString("Cannot query multiple inference IDs in a mixed-version cluster"));
+                }
+            }
+        };
+
+        for (int i = 0; i < 100; i++) {
+            TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
+                random(),
+                TransportVersions.V_8_15_0,
+                TransportVersion.current()
+            );
+            assertMultipleInferenceResults.accept(List.of(inferenceResults1, inferenceResults2), transportVersion);
+        }
+    }
+
     public void testToXContent() throws IOException {
         QueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar");
         checkGeneratedJson("""