Skip to content

Commit 7c2dc4d

Browse files
committed
Added semantic text multi-match query test
1 parent 42c10c6 commit 7c2dc4d

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.apache.lucene.index.Term;
11+
import org.apache.lucene.search.DisjunctionMaxQuery;
12+
import org.apache.lucene.search.Query;
13+
import org.apache.lucene.search.TermQuery;
14+
import org.elasticsearch.cluster.ClusterChangedEvent;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.core.IOUtils;
17+
import org.elasticsearch.index.mapper.MapperService;
18+
import org.elasticsearch.index.mapper.MapperServiceTestCase;
19+
import org.elasticsearch.index.mapper.ParsedDocument;
20+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
21+
import org.elasticsearch.index.query.SearchExecutionContext;
22+
import org.elasticsearch.plugins.Plugin;
23+
import org.elasticsearch.test.ClusterServiceUtils;
24+
import org.elasticsearch.test.client.NoOpClient;
25+
import org.elasticsearch.threadpool.TestThreadPool;
26+
import org.elasticsearch.xpack.inference.InferencePlugin;
27+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
28+
import org.junit.AfterClass;
29+
import org.junit.BeforeClass;
30+
31+
import java.util.Collection;
32+
import java.util.List;
33+
import java.util.function.Supplier;
34+
35+
public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase {
36+
private static TestThreadPool threadPool;
37+
private static ModelRegistry modelRegistry;
38+
39+
private static class InferencePluginWithModelRegistry extends InferencePlugin {
40+
InferencePluginWithModelRegistry(Settings settings) {
41+
super(settings);
42+
}
43+
44+
@Override
45+
protected Supplier<ModelRegistry> getModelRegistry() {
46+
return () -> modelRegistry;
47+
}
48+
}
49+
50+
@BeforeClass
51+
public static void startModelRegistry() {
52+
threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName());
53+
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
54+
modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
55+
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
56+
@Override
57+
public boolean localNodeMaster() {
58+
return false;
59+
}
60+
});
61+
}
62+
63+
@AfterClass
64+
public static void stopModelRegistry() {
65+
IOUtils.closeWhileHandlingException(threadPool);
66+
}
67+
68+
@Override
69+
protected Collection<? extends Plugin> getPlugins() {
70+
return List.of(new InferencePluginWithModelRegistry(Settings.EMPTY));
71+
}
72+
73+
public void testResolveSemanticTextFieldFromWildcard() throws Exception {
74+
MapperService mapperService = createMapperService("""
75+
{
76+
"_doc" : {
77+
"properties": {
78+
"text_field": { "type": "text" },
79+
"keyword_field": { "type": "keyword" },
80+
"inference_field": { "type": "semantic_text", "inference_id": "test_service" }
81+
}
82+
}
83+
}
84+
""");
85+
86+
ParsedDocument doc = mapperService.documentMapper().parse(source("""
87+
{
88+
"text_field" : "foo",
89+
"keyword_field" : "foo",
90+
"inference_field" : "foo",
91+
"_inference_fields": {
92+
"inference_field": {
93+
"inference": {
94+
"inference_id": "test_service",
95+
"model_settings": {
96+
"task_type": "sparse_embedding"
97+
},
98+
"chunks": {
99+
"inference_field": [
100+
{
101+
"start_offset": 0,
102+
"end_offset": 3,
103+
"embeddings": {
104+
"foo": 1.0
105+
}
106+
}
107+
]
108+
}
109+
}
110+
}
111+
}
112+
}
113+
"""));
114+
115+
withLuceneIndex(mapperService, iw -> iw.addDocument(doc.rootDoc()), ir -> {
116+
SearchExecutionContext context = createSearchExecutionContext(mapperService, newSearcher(ir));
117+
Query query = new MultiMatchQueryBuilder("foo", "*_field").toQuery(context);
118+
Query expected = new DisjunctionMaxQuery(
119+
List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))),
120+
0f
121+
);
122+
assertEquals(expected, query);
123+
});
124+
}
125+
}

0 commit comments

Comments
 (0)