1+ package org .elasticsearch .xpack .rank .linear ;
2+
3+ import org .apache .lucene .search .ScoreDoc ;
4+ import org .elasticsearch .search .rank .RankDoc ;
5+ import org .elasticsearch .search .retriever .CompoundRetrieverBuilder .RetrieverSource ;
6+ import org .elasticsearch .search .retriever .TestRetrieverBuilder ;
7+ import org .elasticsearch .test .ESTestCase ;
8+
9+ import java .io .IOException ;
10+ import java .util .List ;
11+ import java .util .stream .Stream ;
12+
13+ public class LinearRetrieverBuilderTests extends ESTestCase {
14+
15+ public void testCombineInnerRetrieverResultsCalculatesWeightedSum () throws IOException {
16+ List <RetrieverSource > sources = List .of (
17+ new RetrieverSource (new TestRetrieverBuilder ("r1" ), null ),
18+ new RetrieverSource (new TestRetrieverBuilder ("r2" ), null )
19+ );
20+ int rankWindowSize = 5 ;
21+ float [] weights = new float [] { 1.0f , 2.0f };
22+ ScoreNormalizer [] normalizers = new ScoreNormalizer [] { IdentityScoreNormalizer .INSTANCE , IdentityScoreNormalizer .INSTANCE };
23+ float minScore = 0f ;
24+
25+ LinearRetrieverBuilder builder = new LinearRetrieverBuilder (
26+ sources , rankWindowSize , weights , normalizers , minScore
27+ );
28+
29+ ScoreDoc [] left = new ScoreDoc [] { new ScoreDoc (5 , 1.0f , 0 ), new ScoreDoc (6 , 2.0f , 0 ) };
30+ ScoreDoc [] right = new ScoreDoc [] { new ScoreDoc (5 , 3.0f , 0 ), new ScoreDoc (6 , 1.0f , 0 ), new ScoreDoc (7 , 0.5f , 0 ) };
31+
32+ RankDoc [] results = builder .combineInnerRetrieverResults (List .of (left , right ), false );
33+
34+ assertEquals ("Should have 3 combined docs" , 3 , results .length );
35+
36+ assertEquals ("Doc 5 score" , 7.0f , results [0 ].score , 0.001f );
37+ assertEquals ("Doc 5 ID" , 5 , results [0 ].doc );
38+ assertEquals ("Doc 5 rank" , 1 , results [0 ].rank );
39+
40+ assertEquals ("Doc 6 score" , 4.0f , results [1 ].score , 0.001f );
41+ assertEquals ("Doc 6 ID" , 6 , results [1 ].doc );
42+ assertEquals ("Doc 6 rank" , 2 , results [1 ].rank );
43+
44+ assertEquals ("Doc 7 score" , 1.0f , results [2 ].score , 0.001f );
45+ assertEquals ("Doc 7 ID" , 7 , results [2 ].doc );
46+ assertEquals ("Doc 7 rank" , 3 , results [2 ].rank );
47+ }
48+
49+ public void testCombineAndFilterWithMinScore () throws IOException {
50+ List <RetrieverSource > sources = List .of (
51+ new RetrieverSource (new TestRetrieverBuilder ("r1" ), null ),
52+ new RetrieverSource (new TestRetrieverBuilder ("r2" ), null )
53+ );
54+ int rankWindowSize = 5 ;
55+ float [] weights = new float [] { 1.0f , 1.0f };
56+ ScoreNormalizer [] normalizers = new ScoreNormalizer [] { IdentityScoreNormalizer .INSTANCE , IdentityScoreNormalizer .INSTANCE };
57+ final float minScoreThreshold = 1.5f ;
58+
59+ LinearRetrieverBuilder builder = new LinearRetrieverBuilder (
60+ sources , rankWindowSize , weights , normalizers , minScoreThreshold
61+ );
62+ assertEquals (minScoreThreshold , builder .getMinScore (), 0f );
63+
64+ ScoreDoc [] left = new ScoreDoc [] { new ScoreDoc (0 , 1.0f , 0 ), new ScoreDoc (1 , 0.8f , 0 ) };
65+ ScoreDoc [] right = new ScoreDoc [] { new ScoreDoc (0 , 2.0f , 0 ), new ScoreDoc (1 , 0.6f , 0 ), new ScoreDoc (2 , 1.8f , 0 ) };
66+
67+ RankDoc [] combinedRankDocs = builder .combineInnerRetrieverResults (List .of (left , right ), false );
68+ assertEquals ("Combined docs before filtering" , 3 , combinedRankDocs .length );
69+
70+ List <RankDoc > filteredDocs = Stream .of (combinedRankDocs )
71+ .filter (rd -> rd .score >= builder .getMinScore ())
72+ .sorted ()
73+ .toList ();
74+
75+ assertEquals ("Filtered docs count" , 2 , filteredDocs .size ());
76+
77+ boolean foundDoc0 = false ;
78+ boolean foundDoc2 = false ;
79+ for (RankDoc scoreDoc : filteredDocs ) {
80+ assertTrue ("Score should be >= minScore" , scoreDoc .score >= minScoreThreshold );
81+ if (scoreDoc .doc == 0 ) {
82+ assertEquals ("Doc 0 score" , 3.0f , scoreDoc .score , 0.001f );
83+ foundDoc0 = true ;
84+ } else if (scoreDoc .doc == 2 ) {
85+ assertEquals ("Doc 2 score" , 1.8f , scoreDoc .score , 0.001f );
86+ foundDoc2 = true ;
87+ } else {
88+ fail ("Unexpected document ID returned: " + scoreDoc .doc + " (should have been filtered)" );
89+ }
90+ }
91+ assertTrue ("Document 0 should have been found" , foundDoc0 );
92+ assertTrue ("Document 2 should have been found" , foundDoc2 );
93+ }
94+ }
0 commit comments