Skip to content

Commit c7b5390

Browse files
committed
Add serialization BwC test
1 parent 06f29f5 commit c7b5390

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,23 @@
2020
import org.apache.lucene.search.Query;
2121
import org.apache.lucene.search.join.ScoreMode;
2222
import org.apache.lucene.tests.index.RandomIndexWriter;
23+
import org.elasticsearch.TransportVersion;
24+
import org.elasticsearch.TransportVersions;
2325
import org.elasticsearch.action.ActionListener;
2426
import org.elasticsearch.action.ActionRequest;
2527
import org.elasticsearch.action.ActionType;
2628
import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
2729
import org.elasticsearch.client.internal.Client;
2830
import org.elasticsearch.cluster.ClusterChangedEvent;
2931
import org.elasticsearch.cluster.metadata.IndexMetadata;
32+
import org.elasticsearch.common.CheckedBiConsumer;
3033
import org.elasticsearch.common.Strings;
3134
import org.elasticsearch.common.bytes.BytesReference;
3235
import org.elasticsearch.common.compress.CompressedXContent;
36+
import org.elasticsearch.common.io.stream.BytesStreamOutput;
37+
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
3338
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
39+
import org.elasticsearch.common.io.stream.StreamInput;
3440
import org.elasticsearch.common.settings.Settings;
3541
import org.elasticsearch.core.IOUtils;
3642
import org.elasticsearch.core.Nullable;
@@ -46,6 +52,7 @@
4652
import org.elasticsearch.index.query.QueryRewriteContext;
4753
import org.elasticsearch.index.query.SearchExecutionContext;
4854
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
55+
import org.elasticsearch.inference.InferenceResults;
4956
import org.elasticsearch.inference.InputType;
5057
import org.elasticsearch.inference.MinimalServiceSettings;
5158
import org.elasticsearch.inference.SimilarityMeasure;
@@ -55,6 +62,7 @@
5562
import org.elasticsearch.search.vectors.SparseVectorQueryWrapper;
5663
import org.elasticsearch.test.AbstractQueryTestCase;
5764
import org.elasticsearch.test.ClusterServiceUtils;
65+
import org.elasticsearch.test.TransportVersionUtils;
5866
import org.elasticsearch.test.client.NoOpClient;
5967
import org.elasticsearch.threadpool.TestThreadPool;
6068
import org.elasticsearch.xcontent.XContentBuilder;
@@ -86,6 +94,7 @@
8694
import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
8795
import static org.apache.lucene.search.BooleanClause.Occur.MUST;
8896
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
97+
import static org.hamcrest.Matchers.containsString;
8998
import static org.hamcrest.Matchers.equalTo;
9099
import static org.hamcrest.Matchers.instanceOf;
91100
import static org.hamcrest.Matchers.notNullValue;
@@ -361,6 +370,97 @@ public void testIllegalValues() {
361370
}
362371
}
363372

373+
public void testSerializationBwc() throws IOException {
374+
InferenceResults inferenceResults1 = new TextExpansionResults(
375+
DEFAULT_RESULTS_FIELD,
376+
List.of(new WeightedToken("foo", 1.0f)),
377+
false
378+
);
379+
InferenceResults inferenceResults2 = new TextExpansionResults(
380+
DEFAULT_RESULTS_FIELD,
381+
List.of(new WeightedToken("bar", 2.0f)),
382+
false
383+
);
384+
385+
// Single inference result
386+
CheckedBiConsumer<InferenceResults, TransportVersion, IOException> assertSingleInferenceResult = (inferenceResults, version) -> {
387+
String fieldName = randomAlphaOfLength(5);
388+
String query = randomAlphaOfLength(5);
389+
390+
MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
391+
mapInferenceResultsProvider.addInferenceResults(randomAlphaOfLength(5), inferenceResults);
392+
SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(fieldName, query, null, mapInferenceResultsProvider);
393+
394+
SingleInferenceResultsProvider singleInferenceResultsProvider = new SingleInferenceResultsProvider(inferenceResults);
395+
SemanticQueryBuilder bwcQuery = new SemanticQueryBuilder(fieldName, query, null, singleInferenceResultsProvider);
396+
397+
try (BytesStreamOutput output = new BytesStreamOutput()) {
398+
output.setTransportVersion(version);
399+
output.writeNamedWriteable(originalQuery);
400+
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
401+
in.setTransportVersion(version);
402+
QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
403+
404+
SemanticQueryBuilder expectedQuery = version.onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)
405+
? originalQuery
406+
: bwcQuery;
407+
assertThat(deserializedQuery, equalTo(expectedQuery));
408+
}
409+
}
410+
};
411+
412+
for (int i = 0; i < 100; i++) {
413+
TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
414+
random(),
415+
TransportVersions.V_8_15_0,
416+
TransportVersion.current()
417+
);
418+
assertSingleInferenceResult.accept(inferenceResults1, transportVersion);
419+
}
420+
421+
// Multiple inference results
422+
CheckedBiConsumer<List<InferenceResults>, TransportVersion, IOException> assertMultipleInferenceResults = (
423+
inferenceResultsList,
424+
version) -> {
425+
MapInferenceResultsProvider mapInferenceResultsProvider = new MapInferenceResultsProvider();
426+
inferenceResultsList.forEach(result -> mapInferenceResultsProvider.addInferenceResults(randomAlphaOfLength(5), result));
427+
SemanticQueryBuilder originalQuery = new SemanticQueryBuilder(
428+
randomAlphaOfLength(5),
429+
randomAlphaOfLength(5),
430+
null,
431+
mapInferenceResultsProvider
432+
);
433+
434+
try (BytesStreamOutput output = new BytesStreamOutput()) {
435+
output.setTransportVersion(version);
436+
437+
if (version.onOrAfter(TransportVersions.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS)) {
438+
output.writeNamedWriteable(originalQuery);
439+
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
440+
in.setTransportVersion(version);
441+
QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
442+
assertThat(deserializedQuery, equalTo(originalQuery));
443+
}
444+
} else {
445+
IllegalArgumentException e = assertThrows(
446+
IllegalArgumentException.class,
447+
() -> output.writeNamedWriteable(originalQuery)
448+
);
449+
assertThat(e.getMessage(), containsString("Cannot query multiple inference IDs in a mixed-version cluster"));
450+
}
451+
}
452+
};
453+
454+
for (int i = 0; i < 100; i++) {
455+
TransportVersion transportVersion = TransportVersionUtils.randomVersionBetween(
456+
random(),
457+
TransportVersions.V_8_15_0,
458+
TransportVersion.current()
459+
);
460+
assertMultipleInferenceResults.accept(List.of(inferenceResults1, inferenceResults2), transportVersion);
461+
}
462+
}
463+
364464
public void testToXContent() throws IOException {
365465
QueryBuilder queryBuilder = new SemanticQueryBuilder("foo", "bar");
366466
checkGeneratedJson("""

0 commit comments

Comments
 (0)