Skip to content

Commit e2e65db

Browse files
committed
additional tests and refactoring
1 parent 4a23c9c commit e2e65db

File tree

2 files changed

+38
-44
lines changed

2 files changed

+38
-44
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.util.Map;
4343
import java.util.Objects;
4444

45+
import static org.elasticsearch.TransportVersions.SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19;
4546
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
4647
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
4748
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -124,7 +125,9 @@ public SparseVectorQueryBuilder(
124125
public SparseVectorQueryBuilder(StreamInput in) throws IOException {
125126
super(in);
126127
this.fieldName = in.readString();
127-
if (in.getTransportVersion().onOrAfter(TransportVersions.SPARSE_VECTOR_FIELD_PRUNING_OPTIONS)) {
128+
if (in.getTransportVersion().isPatchFrom(SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19) ||
129+
in.getTransportVersion().onOrAfter(TransportVersions.SPARSE_VECTOR_FIELD_PRUNING_OPTIONS)
130+
) {
128131
this.shouldPruneTokens = in.readOptionalBoolean();
129132
} else {
130133
this.shouldPruneTokens = in.readBoolean();
@@ -177,7 +180,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
177180
}
178181

179182
out.writeString(fieldName);
180-
if (out.getTransportVersion().onOrAfter(TransportVersions.SPARSE_VECTOR_FIELD_PRUNING_OPTIONS)) {
183+
if (out.getTransportVersion().isPatchFrom(SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19)
184+
|| out.getTransportVersion().onOrAfter(TransportVersions.SPARSE_VECTOR_FIELD_PRUNING_OPTIONS)
185+
) {
181186
out.writeOptionalBoolean(shouldPruneTokens);
182187
} else {
183188
out.writeBoolean(shouldPruneTokens != null && shouldPruneTokens);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,26 @@ protected Object simulateMethod(Method method, Object[] args) {
155155

156156
@Override
157157
protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
158-
if (shouldInjectSparseVectorIndexOptions()) {
159-
addSparseVectorIndexOptionsMapping(mapperService);
160-
return;
161-
}
162-
163158
mapperService.merge(
164159
"_doc",
165-
new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(SPARSE_VECTOR_FIELD, "type=sparse_vector"))),
160+
new CompressedXContent(getTestSparseVectorIndexMapping()),
166161
MapperService.MergeReason.MAPPING_UPDATE
167162
);
168163
}
169164

170-
private boolean shouldInjectSparseVectorIndexOptions() {
165+
private String getTestSparseVectorIndexMapping() {
166+
if (currentTestHasIndexOptions()) {
167+
return "{\"properties\":{\""
168+
+ SPARSE_VECTOR_FIELD
169+
+ "\":{\"type\":\"sparse_vector\",\"index_options\""
170+
+ ":{\"prune\":true,\"pruning_config\":{\"tokens_freq_ratio_threshold\""
171+
+ ":12,\"tokens_weight_threshold\":0.6}}}}}";
172+
}
173+
174+
return Strings.toString(PutMappingRequest.simpleMapping(SPARSE_VECTOR_FIELD, "type=sparse_vector"));
175+
}
176+
177+
private boolean currentTestHasIndexOptions() {
171178
Class<?> clazz = this.getClass();
172179
Class<InjectSparseVectorIndexOptions> injectSparseVectorIndexOptions = InjectSparseVectorIndexOptions.class;
173180

@@ -179,15 +186,6 @@ private boolean shouldInjectSparseVectorIndexOptions() {
179186
}
180187
}
181188

182-
private void addSparseVectorIndexOptionsMapping(MapperService mapperService) throws IOException {
183-
String addIndexOptionsTemplate = "{\"properties\":{\""
184-
+ SPARSE_VECTOR_FIELD
185-
+ "\":{\"type\":\"sparse_vector\",\"index_options\""
186-
+ ":{\"prune\":true,\"pruning_config\":{\"tokens_freq_ratio_threshold\""
187-
+ ":12,\"tokens_weight_threshold\":0.6}}}}}";
188-
mapperService.merge("_doc", new CompressedXContent(addIndexOptionsTemplate), MapperService.MergeReason.MAPPING_UPDATE);
189-
}
190-
191189
@Override
192190
protected void doAssertLuceneQuery(SparseVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
193191
assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
@@ -283,7 +281,7 @@ private void testDoToQuery(SparseVectorQueryBuilder queryBuilder, SearchExecutio
283281

284282
assertTrue(query instanceof SparseVectorQueryWrapper);
285283
var sparseQuery = (SparseVectorQueryWrapper) query;
286-
if (queryBuilder.shouldPruneTokens()) {
284+
if (queryBuilder.shouldPruneTokens() || currentTestHasIndexOptions()) {
287285
// It's possible that all documents were pruned for aggressive pruning configurations
288286
assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery || sparseQuery.getTermsQuery() instanceof MatchNoDocsQuery);
289287
} else {
@@ -386,31 +384,33 @@ public void testThatWeCorrectlyRewriteQueryIntoVectors() {
386384
public void testItUsesIndexOptionsDefaults() throws IOException {
387385
withSearchIndex((context) -> {
388386
try {
389-
SparseVectorQueryBuilder builder = new SparseVectorQueryBuilder(
390-
SPARSE_VECTOR_FIELD,
391-
WEIGHTED_TOKENS,
392-
null,
393-
null,
394-
null,
395-
null
396-
);
397-
Query query = builder.doToQuery(context);
398-
387+
SparseVectorQueryBuilder builder = createTestQueryBuilder(null);
388+
assertFalse(builder.shouldPruneTokens());
389+
testDoToQuery(builder, context);
399390
} catch (IOException ex) {
400391
throw new RuntimeException(ex);
401392
}
402393
});
403394
}
404395

405396
@InjectSparseVectorIndexOptions
406-
public void testItOverridesIndexOptionsDefaults() {
407-
397+
public void testItOverridesIndexOptionsDefaults() throws IOException {
398+
withSearchIndex((context) -> {
399+
try {
400+
TokenPruningConfig pruningConfig = new TokenPruningConfig(2, 0.3f, false);
401+
SparseVectorQueryBuilder builder = createTestQueryBuilder(pruningConfig);
402+
assertTrue(builder.shouldPruneTokens());
403+
testDoToQuery(builder, context);
404+
} catch (IOException ex) {
405+
throw new RuntimeException(ex);
406+
}
407+
});
408408
}
409409

410410
@InjectSparseVectorIndexOptions
411-
public void testToQueryWithIndexOptions() throws IOException {
411+
public void testToQueryRewriteWithIndexOptions() throws IOException {
412412
withSearchIndex((context) -> {
413-
SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder();
413+
SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder(null);
414414
try {
415415
if (queryBuilder.getQueryVectors() == null) {
416416
QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, context);
@@ -424,15 +424,4 @@ public void testToQueryWithIndexOptions() throws IOException {
424424
}
425425
});
426426
}
427-
428-
@InjectSparseVectorIndexOptions
429-
public void testWeCorrectlyRewriteQueryIntoVectorsWithIndexOptions() {
430-
SearchExecutionContext searchExecutionContext = createSearchExecutionContext();
431-
432-
SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder(null);
433-
QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext);
434-
assertTrue(rewrittenQueryBuilder instanceof SparseVectorQueryBuilder);
435-
assertEquals(queryBuilder.shouldPruneTokens(), ((SparseVectorQueryBuilder) rewrittenQueryBuilder).shouldPruneTokens());
436-
assertNotNull(((SparseVectorQueryBuilder) rewrittenQueryBuilder).getQueryVectors());
437-
}
438427
}

0 commit comments

Comments
 (0)