17
17
import org .apache .lucene .search .Query ;
18
18
import org .apache .lucene .store .Directory ;
19
19
import org .apache .lucene .tests .index .RandomIndexWriter ;
20
+ import org .apache .lucene .util .Accountable ;
21
+ import org .elasticsearch .action .search .SearchShardTask ;
22
+ import org .elasticsearch .cluster .metadata .IndexMetadata ;
23
+ import org .elasticsearch .common .settings .Settings ;
24
+ import org .elasticsearch .index .IndexSettings ;
25
+ import org .elasticsearch .index .IndexVersion ;
26
+ import org .elasticsearch .index .cache .bitset .BitsetFilterCache ;
27
+ import org .elasticsearch .index .mapper .MapperMetrics ;
28
+ import org .elasticsearch .index .mapper .MappingLookup ;
29
+ import org .elasticsearch .index .query .ParsedQuery ;
30
+ import org .elasticsearch .index .query .SearchExecutionContext ;
31
+ import org .elasticsearch .index .shard .IndexShard ;
32
+ import org .elasticsearch .index .shard .IndexShardTestCase ;
33
+ import org .elasticsearch .index .shard .SearchOperationListener ;
34
+ import org .elasticsearch .index .shard .ShardId ;
35
+ import org .elasticsearch .search .builder .SearchSourceBuilder ;
20
36
import org .elasticsearch .search .internal .ContextIndexSearcher ;
37
+ import org .elasticsearch .search .internal .SearchContext ;
21
38
import org .elasticsearch .search .profile .Profilers ;
22
39
import org .elasticsearch .search .profile .SearchProfileDfsPhaseResult ;
23
40
import org .elasticsearch .search .profile .query .CollectorResult ;
24
41
import org .elasticsearch .search .profile .query .QueryProfileShardResult ;
25
- import org .elasticsearch .test .ESTestCase ;
42
+ import org .elasticsearch .search .vectors .KnnSearchBuilder ;
43
+ import org .elasticsearch .test .TestSearchContext ;
26
44
import org .elasticsearch .threadpool .TestThreadPool ;
27
45
import org .elasticsearch .threadpool .ThreadPool ;
28
46
import org .junit .After ;
29
47
import org .junit .Before ;
30
48
31
49
import java .io .IOException ;
50
+ import java .util .Collections ;
32
51
import java .util .List ;
33
52
import java .util .concurrent .ThreadPoolExecutor ;
53
+ import java .util .concurrent .atomic .AtomicLong ;
34
54
35
- public class DfsPhaseTests extends ESTestCase {
55
+ import static org .elasticsearch .search .dfs .DfsPhase .executeKnnVectorQuery ;
56
+
57
+ public class DfsPhaseTests extends IndexShardTestCase {
36
58
37
59
ThreadPoolExecutor threadPoolExecutor ;
38
60
private TestThreadPool threadPool ;
@@ -49,6 +71,105 @@ public void cleanup() {
49
71
terminate (threadPool );
50
72
}
51
73
74
+ public void testKnnSearch () throws IOException {
75
+ AtomicLong queryCount = new AtomicLong ();
76
+ AtomicLong queryTime = new AtomicLong ();
77
+
78
+ IndexShard indexShard = newShard (true , List .of (new SearchOperationListener () {
79
+ @ Override
80
+ public void onQueryPhase (SearchContext searchContext , long tookInNanos ) {
81
+ queryCount .incrementAndGet ();
82
+ queryTime .addAndGet (tookInNanos );
83
+ }
84
+ }));
85
+ try (Directory dir = newDirectory (); RandomIndexWriter w = new RandomIndexWriter (random (), dir , newIndexWriterConfig ())) {
86
+ int numDocs = randomIntBetween (900 , 1000 );
87
+ for (int i = 0 ; i < numDocs ; i ++) {
88
+ Document d = new Document ();
89
+ d .add (new KnnFloatVectorField ("float_vector" , new float [] { i , 0 , 0 }));
90
+ w .addDocument (d );
91
+ }
92
+ w .flush ();
93
+
94
+ IndexReader reader = w .getReader ();
95
+ ContextIndexSearcher searcher = new ContextIndexSearcher (
96
+ reader ,
97
+ IndexSearcher .getDefaultSimilarity (),
98
+ IndexSearcher .getDefaultQueryCache (),
99
+ IndexSearcher .getDefaultQueryCachingPolicy (),
100
+ randomBoolean (),
101
+ threadPoolExecutor ,
102
+ threadPoolExecutor .getMaximumPoolSize (),
103
+ 1
104
+ );
105
+ IndexSettings indexSettings = new IndexSettings (
106
+ IndexMetadata .builder ("index" )
107
+ .settings (Settings .builder ().put (IndexMetadata .SETTING_VERSION_CREATED , IndexVersion .current ()))
108
+ .numberOfShards (1 )
109
+ .numberOfReplicas (0 )
110
+ .creationDate (System .currentTimeMillis ())
111
+ .build (),
112
+ Settings .EMPTY
113
+ );
114
+ BitsetFilterCache bitsetFilterCache = new BitsetFilterCache (indexSettings , new BitsetFilterCache .Listener () {
115
+ @ Override
116
+ public void onCache (ShardId shardId , Accountable accountable ) {
117
+
118
+ }
119
+
120
+ @ Override
121
+ public void onRemoval (ShardId shardId , Accountable accountable ) {
122
+
123
+ }
124
+ });
125
+ SearchExecutionContext searchExecutionContext = new SearchExecutionContext (
126
+ 0 ,
127
+ 0 ,
128
+ indexSettings ,
129
+ bitsetFilterCache ,
130
+ null ,
131
+ null ,
132
+ MappingLookup .EMPTY ,
133
+ null ,
134
+ null ,
135
+ null ,
136
+ null ,
137
+ null ,
138
+ null ,
139
+ null ,
140
+ null ,
141
+ null ,
142
+ null ,
143
+ null ,
144
+ Collections .emptyMap (),
145
+ null ,
146
+ MapperMetrics .NOOP
147
+ );
148
+
149
+ Query query = new KnnFloatVectorQuery ("float_vector" , new float [] { 0 , 0 , 0 }, numDocs , null );
150
+ try (TestSearchContext context = new TestSearchContext (searchExecutionContext , indexShard , searcher ) {
151
+ @ Override
152
+ public DfsSearchResult dfsResult () {
153
+ return new DfsSearchResult (null , null , null );
154
+ }
155
+ }) {
156
+ context .request ()
157
+ .source (
158
+ new SearchSourceBuilder ().knnSearch (
159
+ List .of (new KnnSearchBuilder ("float_vector" , new float [] { 0 , 0 , 0 }, numDocs , numDocs , null , null ))
160
+ )
161
+ );
162
+ context .setTask (new SearchShardTask (123L , "" , "" , "" , null , Collections .emptyMap ()));
163
+ context .parsedQuery (new ParsedQuery (query ));
164
+ executeKnnVectorQuery (context );
165
+ assertTrue (queryCount .get () > 0 );
166
+ assertTrue (queryTime .get () > 0 );
167
+ reader .close ();
168
+ closeShards (indexShard );
169
+ }
170
+ }
171
+ }
172
+
52
173
public void testSingleKnnSearch () throws IOException {
53
174
try (Directory dir = newDirectory (); RandomIndexWriter w = new RandomIndexWriter (random (), dir , newIndexWriterConfig ())) {
54
175
int numDocs = randomIntBetween (900 , 1000 );
0 commit comments