99
1010package org .elasticsearch .search .retriever ;
1111
12+ import org .apache .lucene .util .SetOnce ;
1213import org .elasticsearch .common .ParsingException ;
1314import org .elasticsearch .features .NodeFeature ;
1415import org .elasticsearch .index .query .BoolQueryBuilder ;
1516import org .elasticsearch .index .query .QueryBuilder ;
17+ import org .elasticsearch .index .query .QueryRewriteContext ;
1618import org .elasticsearch .search .builder .SearchSourceBuilder ;
1719import org .elasticsearch .search .retriever .rankdoc .RankDocsQueryBuilder ;
1820import org .elasticsearch .search .vectors .ExactKnnQueryBuilder ;
2931import java .util .Arrays ;
3032import java .util .List ;
3133import java .util .Objects ;
34+ import java .util .function .Supplier ;
3235
36+ import static org .elasticsearch .common .Strings .format ;
3337import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
3438import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
3539
@@ -96,7 +100,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
96100 }
97101
98102 private final String field ;
99- private final float [] queryVector ;
103+ private final Supplier < float []> queryVector ;
100104 private final QueryVectorBuilder queryVectorBuilder ;
101105 private final int k ;
102106 private final int numCands ;
@@ -110,23 +114,85 @@ public KnnRetrieverBuilder(
110114 int numCands ,
111115 Float similarity
112116 ) {
117+ if (queryVector == null && queryVectorBuilder == null ) {
118+ throw new IllegalArgumentException (
119+ format (
120+ "either [%s] or [%s] must be provided" ,
121+ QUERY_VECTOR_FIELD .getPreferredName (),
122+ QUERY_VECTOR_BUILDER_FIELD .getPreferredName ()
123+ )
124+ );
125+ } else if (queryVector != null && queryVectorBuilder != null ) {
126+ throw new IllegalArgumentException (
127+ format (
128+ "only one of [%s] and [%s] must be provided" ,
129+ QUERY_VECTOR_FIELD .getPreferredName (),
130+ QUERY_VECTOR_BUILDER_FIELD .getPreferredName ()
131+ )
132+ );
133+ }
113134 this .field = field ;
114- this .queryVector = queryVector ;
135+ this .queryVector = queryVector != null ? () -> queryVector : null ;
115136 this .queryVectorBuilder = queryVectorBuilder ;
116137 this .k = k ;
117138 this .numCands = numCands ;
118139 this .similarity = similarity ;
119140 }
120141
121- // ---- FOR TESTING XCONTENT PARSING ----
142+ private KnnRetrieverBuilder (KnnRetrieverBuilder clone , Supplier <float []> queryVector , QueryVectorBuilder queryVectorBuilder ) {
143+ this .queryVector = queryVector ;
144+ this .queryVectorBuilder = queryVectorBuilder ;
145+ this .field = clone .field ;
146+ this .k = clone .k ;
147+ this .numCands = clone .numCands ;
148+ this .similarity = clone .similarity ;
149+ this .retrieverName = clone .retrieverName ;
150+ this .preFilterQueryBuilders = clone .preFilterQueryBuilders ;
151+ }
122152
123153 @ Override
124154 public String getName () {
125155 return NAME ;
126156 }
127157
158+ @ Override
159+ public RetrieverBuilder rewrite (QueryRewriteContext ctx ) throws IOException {
160+ var rewrittenFilters = rewritePreFilters (ctx );
161+ if (rewrittenFilters != preFilterQueryBuilders ) {
162+ var rewritten = new KnnRetrieverBuilder (this , queryVector , queryVectorBuilder );
163+ rewritten .preFilterQueryBuilders = rewrittenFilters ;
164+ return rewritten ;
165+ }
166+
167+ if (queryVectorBuilder != null ) {
168+ SetOnce <float []> toSet = new SetOnce <>();
169+ ctx .registerAsyncAction ((c , l ) -> {
170+ queryVectorBuilder .buildVector (c , l .delegateFailureAndWrap ((ll , v ) -> {
171+ toSet .set (v );
172+ if (v == null ) {
173+ ll .onFailure (
174+ new IllegalArgumentException (
175+ format (
176+ "[%s] with name [%s] returned null query_vector" ,
177+ QUERY_VECTOR_BUILDER_FIELD .getPreferredName (),
178+ queryVectorBuilder .getWriteableName ()
179+ )
180+ )
181+ );
182+ return ;
183+ }
184+ ll .onResponse (null );
185+ }));
186+ });
187+ var rewritten = new KnnRetrieverBuilder (this , () -> toSet .get (), null );
188+ return rewritten ;
189+ }
190+ return super .rewrite (ctx );
191+ }
192+
128193 @ Override
129194 public QueryBuilder topDocsQuery () {
195+ assert queryVector != null : "query vector must be materialized at this point" ;
130196 assert rankDocs != null : "rankDocs should have been materialized by now" ;
131197 var rankDocsQuery = new RankDocsQueryBuilder (rankDocs , null , true );
132198 if (preFilterQueryBuilders .isEmpty ()) {
@@ -139,10 +205,11 @@ public QueryBuilder topDocsQuery() {
139205
140206 @ Override
141207 public QueryBuilder explainQuery () {
208+ assert queryVector != null : "query vector must be materialized at this point" ;
142209 assert rankDocs != null : "rankDocs should have been materialized by now" ;
143210 var rankDocsQuery = new RankDocsQueryBuilder (
144211 rankDocs ,
145- new QueryBuilder [] { new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector ), field , similarity ) },
212+ new QueryBuilder [] { new ExactKnnQueryBuilder (VectorData .fromFloats (queryVector . get () ), field , similarity ) },
146213 true
147214 );
148215 if (preFilterQueryBuilders .isEmpty ()) {
@@ -155,10 +222,11 @@ public QueryBuilder explainQuery() {
155222
156223 @ Override
157224 public void extractToSearchSourceBuilder (SearchSourceBuilder searchSourceBuilder , boolean compoundUsed ) {
225+ assert queryVector != null : "query vector must be materialized at this point." ;
158226 KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder (
159227 field ,
160- VectorData .fromFloats (queryVector ),
161- queryVectorBuilder ,
228+ VectorData .fromFloats (queryVector . get () ),
229+ null ,
162230 k ,
163231 numCands ,
164232 similarity
@@ -174,14 +242,16 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
174242 searchSourceBuilder .knnSearch (knnSearchBuilders );
175243 }
176244
245+ // ---- FOR TESTING XCONTENT PARSING ----
246+
177247 @ Override
178248 public void doToXContent (XContentBuilder builder , Params params ) throws IOException {
179249 builder .field (FIELD_FIELD .getPreferredName (), field );
180250 builder .field (K_FIELD .getPreferredName (), k );
181251 builder .field (NUM_CANDS_FIELD .getPreferredName (), numCands );
182252
183253 if (queryVector != null ) {
184- builder .field (QUERY_VECTOR_FIELD .getPreferredName (), queryVector );
254+ builder .field (QUERY_VECTOR_FIELD .getPreferredName (), queryVector . get () );
185255 }
186256
187257 if (queryVectorBuilder != null ) {
@@ -199,15 +269,16 @@ public boolean doEquals(Object o) {
199269 return k == that .k
200270 && numCands == that .numCands
201271 && Objects .equals (field , that .field )
202- && Arrays .equals (queryVector , that .queryVector )
272+ && ((queryVector == null && that .queryVector == null )
273+ || (queryVector != null && that .queryVector != null && Arrays .equals (queryVector .get (), that .queryVector .get ())))
203274 && Objects .equals (queryVectorBuilder , that .queryVectorBuilder )
204275 && Objects .equals (similarity , that .similarity );
205276 }
206277
207278 @ Override
208279 public int doHashCode () {
209280 int result = Objects .hash (field , queryVectorBuilder , k , numCands , similarity );
210- result = 31 * result + Arrays .hashCode (queryVector );
281+ result = 31 * result + Arrays .hashCode (queryVector != null ? queryVector . get () : null );
211282 return result ;
212283 }
213284
0 commit comments