Skip to content

Commit a989226

Browse files
committed
The testcase is working now
1 parent 1598701 commit a989226

File tree

2 files changed

+68
-103
lines changed

2 files changed

+68
-103
lines changed

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java

Lines changed: 67 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,14 @@
4646
import java.util.Arrays;
4747
import java.util.Collection;
4848
import java.util.List;
49-
import java.util.Map;
5049
import java.util.concurrent.atomic.AtomicInteger;
50+
import java.util.stream.Collectors;
5151

5252
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
5353
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
5454
import static org.hamcrest.CoreMatchers.is;
5555
import static org.hamcrest.Matchers.closeTo;
56+
import static org.hamcrest.Matchers.containsInAnyOrder;
5657
import static org.hamcrest.Matchers.containsString;
5758
import static org.hamcrest.Matchers.equalTo;
5859
import static org.hamcrest.Matchers.instanceOf;
@@ -842,166 +843,135 @@ public void testLinearRetrieverWithMinScoreValidation() {
842843
StandardRetrieverBuilder retriever1 = new StandardRetrieverBuilder(new MatchAllQueryBuilder());
843844
float[] weights = new float[] { 1.0f };
844845
ScoreNormalizer[] normalizers = LinearRetrieverBuilder.getDefaultNormalizers(1);
845-
846-
// Test negative minScore
847846
LinearRetrieverBuilder builder = new LinearRetrieverBuilder(
848847
List.of(new CompoundRetrieverBuilder.RetrieverSource(retriever1, null)),
849848
10,
850849
weights,
851850
normalizers
852851
);
853-
854-
// This should throw an exception
855-
IllegalArgumentException e = expectThrows(
856-
IllegalArgumentException.class,
857-
() -> builder.minScore(-0.1f)
858-
);
852+
853+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> builder.minScore(-0.1f));
859854
assertThat(e.getMessage(), equalTo("[min_score] must be greater than or equal to 0, was: -0.1"));
860-
861-
// Test valid minScore
862-
builder.minScore(0.1f); // This should not throw
855+
856+
builder.minScore(0.1f);
863857
assertThat(builder.minScore(), equalTo(0.1f));
864858
}
865859

866860
// public void testLinearRetrieverWithMinScoreScenarios() {
867861
// final int rankWindowSize = 10;
868862

869-
// // Define scores for TestRetrieverBuilder (documents exist from setupIndex)
870-
// TestRetrieverBuilder retrieverA = new TestRetrieverBuilder(
871-
// Map.of("doc_1", 10.0f, "doc_2", 8.0f, "doc_3", 6.0f, "doc_4", 4.0f)
872-
// );
873-
// TestRetrieverBuilder retrieverB = new TestRetrieverBuilder(
874-
// Map.of("doc_1", 1.0f, "doc_2", 3.0f, "doc_3", 5.0f, "doc_4", 2.0f)
875-
// );
863+
// // Setup test data
864+
// indexDoc(INDEX, "doc_1", TEXT_FIELD, "term1", "views.last30d", 10, "views.all", 100);
865+
// indexDoc(INDEX, "doc_2", TEXT_FIELD, "term1 term2", "views.last30d", 20, "views.all", 200);
866+
// indexDoc(INDEX, "doc_3", TEXT_FIELD, "term1 term2 term3", "views.last30d", 30, "views.all", 300);
867+
// indexDoc(INDEX, "doc_4", TEXT_FIELD, "term4", "views.last30d", 40, "views.all", 400);
868+
// refresh(INDEX);
876869

877-
// // Combined scores (weights {1.0f, 1.0f}, no normalization initially):
878-
// // doc_1: 10.0 + 1.0 = 11.0
879-
// // doc_2: 8.0 + 3.0 = 11.0
880-
// // doc_3: 6.0 + 5.0 = 11.0
881-
// // doc_4: 4.0 + 2.0 = 6.0
870+
// // Create retrievers with different scoring
871+
// StandardRetrieverBuilder retrieverA = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term1").boost(10.0f));
872+
// StandardRetrieverBuilder retrieverB = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term2").boost(1.0f));
882873

883874
// float[] weights = new float[] { 1.0f, 1.0f };
884875
// ScoreNormalizer[] identityNormalizers = LinearRetrieverBuilder.getDefaultNormalizers(2);
885876

886-
// // Scenario 1: minScore is null (not specified) - all docs returned
887-
// LinearRetrieverBuilder builderNullMinScore = new LinearRetrieverBuilder(
877+
// // Scenario 1: No min_score - all docs returned
878+
// LinearRetrieverBuilder builderNoMinScore = new LinearRetrieverBuilder(
888879
// List.of(
889880
// new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null),
890881
// new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null)
891882
// ),
892883
// rankWindowSize,
893884
// weights,
894-
// identityNormalizers,
895-
// null // Explicitly null
885+
// identityNormalizers
896886
// );
897-
// SearchSourceBuilder sourceNullMinScore = new SearchSourceBuilder().retriever(builderNullMinScore).size(rankWindowSize);
898-
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceNullMinScore), resp -> {
899-
// assertThat(resp.getHits().getTotalHits().value, equalTo(4L));
900-
// List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).sorted().toList();
901-
// assertThat(ids, equalTo(List.of("doc_1", "doc_2", "doc_3", "doc_4")));
887+
888+
// SearchSourceBuilder sourceNoMinScore = new SearchSourceBuilder().retriever(builderNoMinScore).size(rankWindowSize);
889+
890+
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceNoMinScore), resp -> {
891+
// assertThat(resp.getHits().getTotalHits().value(), equalTo(3L)); // doc_1, doc_2, doc_3 match
892+
// assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_3")); // term1(10) + term2(1) = 11
893+
// assertThat(resp.getHits().getHits()[1].getId(), equalTo("doc_2")); // term1(10) + term2(1) = 11
894+
// assertThat(resp.getHits().getHits()[2].getId(), equalTo("doc_1")); // term1(10) = 10
902895
// });
903896

904-
// // Scenario 2: minScore = 0.0f - all docs returned (as all scores are > 0)
897+
// // Scenario 2: minScore = 0.0f - all matching docs returned (inclusive)
905898
// LinearRetrieverBuilder builderZeroMinScore = new LinearRetrieverBuilder(
906899
// List.of(
907900
// new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null),
908901
// new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null)
909902
// ),
910903
// rankWindowSize,
911904
// weights,
912-
// identityNormalizers,
913-
// 0.0f
914-
// );
905+
// identityNormalizers
906+
// ).minScore(0.0f);
907+
915908
// SearchSourceBuilder sourceZeroMinScore = new SearchSourceBuilder().retriever(builderZeroMinScore).size(rankWindowSize);
916-
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceZeroMinScore), resp -> {
917-
// assertThat(resp.getHits().getTotalHits().value, equalTo(4L));
918-
// });
919909

920-
// // Scenario 3: Basic filtering - minScore = 10.0f
921-
// // Expect: doc_1 (11.0), doc_2 (11.0), doc_3 (11.0). doc_4 (6.0) is filtered out.
910+
// ElasticsearchAssertions.assertResponse(
911+
// client().prepareSearch(INDEX).setSource(sourceZeroMinScore),
912+
// resp -> assertThat(resp.getHits().getTotalHits().value(), equalTo(3L))
913+
// );
914+
915+
// // Scenario 3: Basic filtering - minScore = 10.5f
922916
// LinearRetrieverBuilder builderFilterBasic = new LinearRetrieverBuilder(
923917
// List.of(
924918
// new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null),
925919
// new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null)
926920
// ),
927921
// rankWindowSize,
928922
// weights,
929-
// identityNormalizers,
930-
// 10.0f
931-
// );
923+
// identityNormalizers
924+
// ).minScore(10.5f);
925+
932926
// SearchSourceBuilder sourceFilterBasic = new SearchSourceBuilder().retriever(builderFilterBasic).size(rankWindowSize);
933-
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterBasic), resp -> {
934-
// assertThat(resp.getHits().getTotalHits().value, equalTo(3L));
935-
// List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).sorted().toList();
936-
// assertThat(ids, equalTo(List.of("doc_1", "doc_2", "doc_3")));
937-
// });
938927

939-
// // Scenario 4: Inclusive filtering - minScore = 6.0f
940-
// // Expect: doc_1 (11.0), doc_2 (11.0), doc_3 (11.0), doc_4 (6.0). doc_4 is included.
941-
// LinearRetrieverBuilder builderFilterInclusive = new LinearRetrieverBuilder(
942-
// List.of(
943-
// new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null),
944-
// new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null)
945-
// ),
946-
// rankWindowSize,
947-
// weights,
948-
// identityNormalizers,
949-
// 6.0f
950-
// );
951-
// SearchSourceBuilder sourceFilterInclusive = new SearchSourceBuilder().retriever(builderFilterInclusive).size(rankWindowSize);
952-
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterInclusive), resp -> {
953-
// assertThat(resp.getHits().getTotalHits().value, equalTo(4L));
954-
// for (var hit : resp.getHits().getHits()) {
955-
// if (hit.getId().equals("doc_4")) assertThat((double)hit.getScore(), closeTo(6.0, 1e-5));
956-
// else assertThat((double)hit.getScore(), closeTo(11.0, 1e-5));
957-
// }
928+
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterBasic), resp -> {
929+
// assertThat(resp.getHits().getTotalHits().value(), equalTo(2L)); // doc_2 and doc_3 have score 11.0
930+
// List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).collect(Collectors.toList());
931+
// assertThat(ids, containsInAnyOrder("doc_2", "doc_3"));
958932
// });
959933

960-
// // Scenario 5: Filter all documents - minScore = 12.0f
934+
// // Scenario 4: Filter all documents - minScore = 20.0f
961935
// LinearRetrieverBuilder builderFilterAll = new LinearRetrieverBuilder(
962936
// List.of(
963937
// new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null),
964938
// new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null)
965939
// ),
966940
// rankWindowSize,
967941
// weights,
968-
// identityNormalizers,
969-
// 12.0f
970-
// );
942+
// identityNormalizers
943+
// ).minScore(20.0f);
944+
971945
// SearchSourceBuilder sourceFilterAll = new SearchSourceBuilder().retriever(builderFilterAll).size(rankWindowSize);
972-
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterAll), resp -> {
973-
// assertThat(resp.getHits().getTotalHits().value, equalTo(0L));
974-
// });
975946

976-
// // Scenario 6: Interaction with MinMax Normalization
977-
// // Retriever A scores: doc_1 (10), doc_2 (8), doc_3 (6), doc_4 (4) -> Normalized: (10-4)/(10-4)=1, (8-4)/(10-4)=0.666, (6-4)/(10-4)=0.333, (4-4)/(10-4)=0
978-
// // Retriever B scores: doc_1 (1), doc_2 (3), doc_3 (5), doc_4 (2) -> Normalized: (1-1)/(5-1)=0, (3-1)/(5-1)=0.5, (5-1)/(5-1)=1, (2-1)/(5-1)=0.25
979-
// // Combined normalized scores (weights {1.0, 1.0}):
980-
// // doc_1: 1.0 + 0.0 = 1.0
981-
// // doc_2: 0.666 + 0.5 = 1.166
982-
// // doc_3: 0.333 + 1.0 = 1.333
983-
// // doc_4: 0.0 + 0.25 = 0.25
947+
// ElasticsearchAssertions.assertResponse(
948+
// client().prepareSearch(INDEX).setSource(sourceFilterAll),
949+
// resp -> assertThat(resp.getHits().getTotalHits().value(), equalTo(0L))
950+
// );
951+
952+
// // Scenario 5: Test with MinMax normalization
953+
// StandardRetrieverBuilder retrieverC = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term1").boost(4.0f));
954+
// StandardRetrieverBuilder retrieverD = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term2").boost(1.0f));
955+
984956
// ScoreNormalizer[] minMaxNormalizers = new ScoreNormalizer[] { MinMaxScoreNormalizer.INSTANCE, MinMaxScoreNormalizer.INSTANCE };
957+
985958
// LinearRetrieverBuilder builderWithNorm = new LinearRetrieverBuilder(
986959
// List.of(
987-
// new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null),
988-
// new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null)
960+
// new CompoundRetrieverBuilder.RetrieverSource(retrieverC, null),
961+
// new CompoundRetrieverBuilder.RetrieverSource(retrieverD, null)
989962
// ),
990963
// rankWindowSize,
991964
// weights,
992-
// minMaxNormalizers,
993-
// 1.1f // minScore after normalization
994-
// );
965+
// minMaxNormalizers
966+
// ).minScore(1.1f);
967+
995968
// SearchSourceBuilder sourceWithNorm = new SearchSourceBuilder().retriever(builderWithNorm).size(rankWindowSize);
969+
996970
// ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceWithNorm), resp -> {
997-
// // Expect doc_2 (1.166), doc_3 (1.333). doc_1 (1.0) and doc_4 (0.25) are filtered out.
998-
// assertThat(resp.getHits().getTotalHits().value, equalTo(2L));
999-
// List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).sorted().toList();
1000-
// assertThat(ids, equalTo(List.of("doc_2", "doc_3")));
1001-
// for (var hit : resp.getHits().getHits()) {
1002-
// if (hit.getId().equals("doc_2")) assertThat((double)hit.getScore(), closeTo(1.166, 0.001));
1003-
// if (hit.getId().equals("doc_3")) assertThat((double)hit.getScore(), closeTo(1.333, 0.001));
1004-
// }
971+
// // With MinMax normalization, we expect doc_2 and doc_3 to have scores > 1.1
972+
// assertThat(resp.getHits().getTotalHits().value(), equalTo(2L));
973+
// List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).collect(Collectors.toList());
974+
// assertThat(ids, containsInAnyOrder("doc_2", "doc_3"));
1005975
// });
1006976
// }
1007977
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,7 @@ public static LinearRetrieverBuilder fromXContent(XContentParser parser, Retriev
106106
}
107107

108108
LinearRetrieverBuilder(List<RetrieverSource> innerRetrievers, int rankWindowSize) {
109-
this(
110-
innerRetrievers,
111-
rankWindowSize,
112-
getDefaultWeight(innerRetrievers.size()),
113-
getDefaultNormalizers(innerRetrievers.size())
114-
);
109+
this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size()));
115110
}
116111

117112
public LinearRetrieverBuilder(

0 commit comments

Comments
 (0)