Skip to content

Commit 96f5aa6

Browse files
refactoring unit tests
1 parent f73285d commit 96f5aa6

File tree

3 files changed

+30
-73
lines changed

3 files changed

+30
-73
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,41 +54,20 @@ public void cleanup() {
5454
}
5555

5656
public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException {
57+
float boost = randomFloat() * 5;
58+
String queryName = randomAlphaOfLength(5);
5759
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
5860
FIELD_NAME,
5961
new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null)
6062
);
6163
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
6264
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
6365
KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
66+
original.boost(boost);
67+
original.queryName(queryName);
6468
testRewrittenInferenceQuery(context, original);
6569
}
6670

67-
public void testKnnQueryWithVectorBuilderIsInterceptedAndRewrittenWithBoostAndQueryName() throws IOException {
68-
float BOOST = 5.0f;
69-
String QUERY_NAME = "knn_query";
70-
71-
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
72-
FIELD_NAME,
73-
new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null)
74-
);
75-
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
76-
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
77-
KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
78-
original.boost(BOOST);
79-
original.queryName(QUERY_NAME);
80-
81-
testRewrittenInferenceQuery(context, original);
82-
QueryBuilder rewritten = original.rewrite(context);
83-
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
84-
assertEquals(BOOST, intercepted.boost(), 0.0f);
85-
assertEquals(QUERY_NAME, intercepted.queryName());
86-
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
87-
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) nestedQueryBuilder.query();
88-
assertEquals(BOOST, knnVectorQueryBuilder.boost(), 5.0f);
89-
assertNull(knnVectorQueryBuilder.queryName());
90-
}
91-
9271
public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
9372
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
9473
FIELD_NAME,
@@ -107,14 +86,23 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ
10786
rewritten instanceof InterceptedQueryBuilderWrapper
10887
);
10988
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
89+
assertEquals(original.boost(), intercepted.boost(), 0.0f);
90+
assertEquals(original.queryName(), intercepted.queryName());
11091
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
92+
11193
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
94+
assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
95+
assertEquals(original.queryName(), nestedQueryBuilder.queryName());
11296
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
97+
11398
QueryBuilder innerQuery = nestedQueryBuilder.query();
11499
assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
115100
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
101+
assertEquals(1.0f, knnVectorQueryBuilder.boost(), 0.0f);
102+
assertNull(knnVectorQueryBuilder.queryName());
116103
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
117104
assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder);
105+
118106
TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder
119107
.queryVectorBuilder();
120108
assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText());

x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOExcept
8181
assertEquals(original, rewritten);
8282
}
8383

84-
public void testBoostInMatchQueryRewrite() throws IOException {
84+
public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException {
8585
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
8686
FIELD_NAME,
8787
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
8888
);
8989
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
90-
QueryBuilder original = createTestQueryBuilderWithBoost();
90+
QueryBuilder original = createTestQueryBuilderWithBoostAndQueryName();
9191
QueryBuilder rewritten = original.rewrite(context);
9292
assertTrue(
9393
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
@@ -106,7 +106,7 @@ private MatchQueryBuilder createTestQueryBuilder() {
106106
return new MatchQueryBuilder(FIELD_NAME, VALUE);
107107
}
108108

109-
private MatchQueryBuilder createTestQueryBuilderWithBoost() {
109+
private MatchQueryBuilder createTestQueryBuilderWithBoostAndQueryName() {
110110
MatchQueryBuilder queryBuilder = createTestQueryBuilder();
111111
queryBuilder.boost(BOOST);
112112
queryBuilder.queryName(QUERY_NAME);

x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,17 @@ public void cleanup() {
5252
}
5353

5454
public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException {
55+
float boost = randomFloat() * 5;
56+
String queryName = randomAlphaOfLength(5);
5557
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
5658
FIELD_NAME,
5759
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
5860
);
5961
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
6062
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
61-
QueryBuilder rewritten = original.rewrite(context);
62-
assertTrue(
63-
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
64-
rewritten instanceof InterceptedQueryBuilderWrapper
65-
);
66-
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
67-
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
68-
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
69-
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
70-
QueryBuilder innerQuery = nestedQueryBuilder.query();
71-
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
72-
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
73-
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
74-
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
75-
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
63+
original.boost(boost);
64+
original.queryName(queryName);
65+
testRewrittenInferenceQuery(context, original);
7666
}
7767

7868
public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
@@ -82,21 +72,7 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted
8272
);
8373
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
8474
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
85-
QueryBuilder rewritten = original.rewrite(context);
86-
assertTrue(
87-
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
88-
rewritten instanceof InterceptedQueryBuilderWrapper
89-
);
90-
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
91-
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
92-
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
93-
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
94-
QueryBuilder innerQuery = nestedQueryBuilder.query();
95-
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
96-
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
97-
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
98-
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
99-
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
75+
testRewrittenInferenceQuery(context, original);
10076
}
10177

10278
public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
@@ -110,36 +86,29 @@ public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IO
11086
assertEquals(original, rewritten);
11187
}
11288

113-
public void testBoostAndQueryNameOnSparseVectorQueryRewrite() throws IOException {
114-
float BOOST = 5.0f;
115-
String QUERY_NAME = "sparse_vector_query";
116-
117-
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
118-
FIELD_NAME,
119-
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
120-
);
121-
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
122-
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
123-
original.boost(BOOST);
124-
original.queryName(QUERY_NAME);
89+
private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException {
12590
QueryBuilder rewritten = original.rewrite(context);
12691
assertTrue(
12792
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
12893
rewritten instanceof InterceptedQueryBuilderWrapper
12994
);
13095
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
131-
assertEquals(BOOST, intercepted.boost(), 0.0f);
132-
assertEquals(QUERY_NAME, intercepted.queryName());
96+
assertEquals(original.boost(), intercepted.boost(), 0.0f);
97+
assertEquals(original.queryName(), intercepted.queryName());
98+
13399
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
134100
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
135101
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
102+
assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
103+
assertEquals(original.queryName(), nestedQueryBuilder.queryName());
104+
136105
QueryBuilder innerQuery = nestedQueryBuilder.query();
137106
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
138107
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
139108
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
140109
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
141110
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
142-
assertEquals(BOOST, sparseVectorQueryBuilder.boost(), 5.0f);
111+
assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f);
143112
assertNull(sparseVectorQueryBuilder.queryName());
144113
}
145114

0 commit comments

Comments
 (0)