@@ -52,20 +52,16 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
5252 assert (queryBuilder instanceof KnnVectorQueryBuilder );
5353 KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder ) queryBuilder ;
5454 Map <String , List <String >> inferenceIdsIndices = indexInformation .getInferenceIdsIndices ();
55- QueryBuilder finalQueryBuilder ;
5655 if (inferenceIdsIndices .size () == 1 ) {
5756 // Simple case, everything uses the same inference ID
5857 Map .Entry <String , List <String >> inferenceIdIndex = inferenceIdsIndices .entrySet ().iterator ().next ();
5958 String searchInferenceId = inferenceIdIndex .getKey ();
6059 List <String > indices = inferenceIdIndex .getValue ();
61- finalQueryBuilder = buildNestedQueryFromKnnVectorQuery (knnVectorQueryBuilder , indices , searchInferenceId );
60+ return buildNestedQueryFromKnnVectorQuery (knnVectorQueryBuilder , indices , searchInferenceId );
6261 } else {
6362 // Multiple inference IDs, construct a boolean query
64- finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds (knnVectorQueryBuilder , inferenceIdsIndices );
63+ return buildInferenceQueryWithMultipleInferenceIds (knnVectorQueryBuilder , inferenceIdsIndices );
6564 }
66- finalQueryBuilder .boost (queryBuilder .boost ());
67- finalQueryBuilder .queryName (queryBuilder .queryName ());
68- return finalQueryBuilder ;
6965 }
7066
7167 private QueryBuilder buildInferenceQueryWithMultipleInferenceIds (
@@ -106,8 +102,6 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
106102 )
107103 );
108104 }
109- boolQueryBuilder .boost (queryBuilder .boost ());
110- boolQueryBuilder .queryName (queryBuilder .queryName ());
111105 return boolQueryBuilder ;
112106 }
113107
@@ -124,17 +118,37 @@ private QueryBuilder buildNestedQueryFromKnnVectorQuery(
124118 }
125119 return QueryBuilders .nestedQuery (
126120 SemanticTextField .getChunksFieldName (filteredKnnVectorQueryBuilder .getFieldName ()),
127- new KnnVectorQueryBuilder (
128- filteredKnnVectorQueryBuilder ,
121+ buildNewKnnVectorQuery (
129122 SemanticTextField .getEmbeddingsFieldName (filteredKnnVectorQueryBuilder .getFieldName ()),
123+ filteredKnnVectorQueryBuilder ,
130124 queryVectorBuilder
131125 ),
132126 ScoreMode .Max
133- );
127+ ). queryName ( knnVectorQueryBuilder . queryName ()). boost ( knnVectorQueryBuilder . boost ()) ;
134128 }
135129
136130 private KnnVectorQueryBuilder addIndexFilterToKnnVectorQuery (Collection <String > indices , KnnVectorQueryBuilder original ) {
137- KnnVectorQueryBuilder copy = new KnnVectorQueryBuilder (original );
131+ KnnVectorQueryBuilder copy ;
132+ if (original .queryVectorBuilder () != null ) {
133+ copy = new KnnVectorQueryBuilder (
134+ original .getFieldName (),
135+ original .queryVectorBuilder (),
136+ original .k (),
137+ original .numCands (),
138+ original .getVectorSimilarity ()
139+ );
140+ } else {
141+ copy = new KnnVectorQueryBuilder (
142+ original .getFieldName (),
143+ original .queryVector (),
144+ original .k (),
145+ original .numCands (),
146+ original .rescoreVectorBuilder (),
147+ original .getVectorSimilarity ()
148+ );
149+ }
150+
151+ copy .addFilterQueries (original .filterQueries ());
138152 copy .addFilterQuery (new TermsQueryBuilder (IndexFieldMapper .NAME , indices ));
139153 return copy ;
140154 }
@@ -148,6 +162,35 @@ private TextEmbeddingQueryVectorBuilder getTextEmbeddingQueryBuilderFromQuery(Kn
148162 return (TextEmbeddingQueryVectorBuilder ) queryVectorBuilder ;
149163 }
150164
165+ private KnnVectorQueryBuilder buildNewKnnVectorQuery (
166+ String fieldName ,
167+ KnnVectorQueryBuilder original ,
168+ QueryVectorBuilder queryVectorBuilder
169+ ) {
170+ KnnVectorQueryBuilder newQueryBuilder ;
171+ if (original .queryVectorBuilder () != null ) {
172+ newQueryBuilder = new KnnVectorQueryBuilder (
173+ fieldName ,
174+ queryVectorBuilder ,
175+ original .k (),
176+ original .numCands (),
177+ original .getVectorSimilarity ()
178+ );
179+ } else {
180+ newQueryBuilder = new KnnVectorQueryBuilder (
181+ fieldName ,
182+ original .queryVector (),
183+ original .k (),
184+ original .numCands (),
185+ original .rescoreVectorBuilder (),
186+ original .getVectorSimilarity ()
187+ );
188+ }
189+
190+ newQueryBuilder .addFilterQueries (original .filterQueries ());
191+ return newQueryBuilder ;
192+ }
193+
151194 @ Override
152195 public String getQueryName () {
153196 return KnnVectorQueryBuilder .NAME ;
0 commit comments