77
88package org .elasticsearch .xpack .rank .rrf ;
99
10+ import org .elasticsearch .action .MockResolvedIndices ;
11+ import org .elasticsearch .action .OriginalIndices ;
12+ import org .elasticsearch .action .ResolvedIndices ;
13+ import org .elasticsearch .action .support .IndicesOptions ;
14+ import org .elasticsearch .cluster .metadata .IndexMetadata ;
15+ import org .elasticsearch .cluster .metadata .InferenceFieldMetadata ;
1016import org .elasticsearch .common .bytes .BytesArray ;
1117import org .elasticsearch .common .settings .Settings ;
18+ import org .elasticsearch .index .Index ;
19+ import org .elasticsearch .index .IndexVersion ;
20+ import org .elasticsearch .index .query .MatchQueryBuilder ;
21+ import org .elasticsearch .index .query .MultiMatchQueryBuilder ;
1222import org .elasticsearch .index .query .QueryRewriteContext ;
1323import org .elasticsearch .search .SearchModule ;
1424import org .elasticsearch .search .builder .PointInTimeBuilder ;
1525import org .elasticsearch .search .builder .SearchSourceBuilder ;
26+ import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
1627import org .elasticsearch .search .retriever .RetrieverBuilder ;
1728import org .elasticsearch .search .retriever .RetrieverParserContext ;
29+ import org .elasticsearch .search .retriever .StandardRetrieverBuilder ;
1830import org .elasticsearch .test .ESTestCase ;
1931import org .elasticsearch .xcontent .NamedXContentRegistry ;
2032import org .elasticsearch .xcontent .ParseField ;
2335
2436import java .io .IOException ;
2537import java .util .List ;
38+ import java .util .Map ;
39+
40+ import static org .elasticsearch .search .retriever .CompoundRetrieverBuilder .convertToRetrieverSource ;
2641
2742/** Tests for the rrf retriever. */
2843public class RRFRetrieverBuilderTests extends ESTestCase {
@@ -66,6 +81,37 @@ public void testRetrieverExtractionErrors() throws IOException {
6681 }
6782 }
6883
84+ public void testSimplifiedParamsRewrite () {
85+ final String indexName = "test-index" ;
86+ final List <String > testInferenceFields = List .of ("semantic_field_1" , "semantic_field_2" );
87+ final ResolvedIndices resolvedIndices = createMockResolvedIndices (indexName , testInferenceFields );
88+ final QueryRewriteContext queryRewriteContext = new QueryRewriteContext (
89+ parserConfig (),
90+ null ,
91+ null ,
92+ resolvedIndices ,
93+ new PointInTimeBuilder (new BytesArray ("pitid" )),
94+ null
95+ );
96+
97+ RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder (
98+ null ,
99+ List .of ("field_1" , "field_2" , "semantic_field_1" , "semantic_field_2" ),
100+ "foo" ,
101+ 10 ,
102+ RRFRetrieverBuilder .DEFAULT_RANK_CONSTANT
103+ );
104+ assertSimplifiedParamsRewrite (
105+ rrfRetrieverBuilder ,
106+ queryRewriteContext ,
107+ Map .of ("field_1" , 1.0f , "field_2" , 1.0f ),
108+ Map .of ("semantic_field_1" , 1.0f , "semantic_field_2" , 1.0f ),
109+ "foo"
110+ );
111+
112+ // TODO: Test with wildcard resolution
113+ }
114+
69115 @ Override
70116 protected NamedXContentRegistry xContentRegistry () {
71117 List <NamedXContentRegistry .Entry > entries = new SearchModule (Settings .EMPTY , List .of ()).getNamedXContents ();
@@ -86,4 +132,69 @@ protected NamedXContentRegistry xContentRegistry() {
86132 );
87133 return new NamedXContentRegistry (entries );
88134 }
135+
136+ private static ResolvedIndices createMockResolvedIndices (String indexName , List <String > inferenceFields ) {
137+ Index index = new Index (indexName , randomAlphaOfLength (10 ));
138+ IndexMetadata .Builder indexMetadataBuilder = IndexMetadata .builder (index .getName ())
139+ .settings (
140+ Settings .builder ()
141+ .put (IndexMetadata .SETTING_VERSION_CREATED , IndexVersion .current ())
142+ .put (IndexMetadata .SETTING_INDEX_UUID , index .getUUID ())
143+ )
144+ .numberOfShards (1 )
145+ .numberOfReplicas (0 );
146+
147+ for (String inferenceField : inferenceFields ) {
148+ indexMetadataBuilder .putInferenceField (
149+ new InferenceFieldMetadata (inferenceField , randomAlphaOfLengthBetween (3 , 5 ), new String [] { inferenceField }, null )
150+ );
151+ }
152+
153+ return new MockResolvedIndices (
154+ Map .of (),
155+ new OriginalIndices (new String [] { indexName }, IndicesOptions .DEFAULT ),
156+ Map .of (index , indexMetadataBuilder .build ())
157+ );
158+ }
159+
160+ private static void assertSimplifiedParamsRewrite (
161+ RRFRetrieverBuilder retriever ,
162+ QueryRewriteContext ctx ,
163+ Map <String , Float > expectedNonInferenceFields ,
164+ Map <String , Float > expectedInferenceFields ,
165+ String expectedQuery
166+ ) {
167+ List <CompoundRetrieverBuilder .RetrieverSource > expectedInnerRetrievers = List .of (
168+ convertToRetrieverSource (
169+ new StandardRetrieverBuilder (
170+ new MultiMatchQueryBuilder (expectedQuery ).type (MultiMatchQueryBuilder .Type .MOST_FIELDS )
171+ .fields (expectedNonInferenceFields )
172+ )
173+ ),
174+ convertToRetrieverSource (
175+ new RRFRetrieverBuilder (
176+ expectedInferenceFields .entrySet ()
177+ .stream ()
178+ .map (e -> {
179+ if (e .getValue () != 1.0f ) {
180+ throw new IllegalArgumentException ("Cannot apply per-field weights in RRF" );
181+ }
182+ return convertToRetrieverSource (new StandardRetrieverBuilder (new MatchQueryBuilder (e .getKey (), expectedQuery )));
183+ })
184+ .toList (),
185+ retriever .rankWindowSize (),
186+ retriever .rankConstant ()
187+ )
188+ )
189+ );
190+ RRFRetrieverBuilder expectedRewritten = new RRFRetrieverBuilder (
191+ expectedInnerRetrievers ,
192+ retriever .rankWindowSize (),
193+ retriever .rankConstant ()
194+ );
195+
196+ RRFRetrieverBuilder rewritten = retriever .doRewrite (ctx );
197+ assertNotSame (retriever , rewritten );
198+ assertEquals (expectedRewritten , rewritten );
199+ }
89200}
0 commit comments