|
46 | 46 | import java.util.Arrays; |
47 | 47 | import java.util.Collection; |
48 | 48 | import java.util.List; |
49 | | -import java.util.Map; |
50 | 49 | import java.util.concurrent.atomic.AtomicInteger; |
| 50 | +import java.util.stream.Collectors; |
51 | 51 |
|
52 | 52 | import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; |
53 | 53 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; |
54 | 54 | import static org.hamcrest.CoreMatchers.is; |
55 | 55 | import static org.hamcrest.Matchers.closeTo; |
| 56 | +import static org.hamcrest.Matchers.containsInAnyOrder; |
56 | 57 | import static org.hamcrest.Matchers.containsString; |
57 | 58 | import static org.hamcrest.Matchers.equalTo; |
58 | 59 | import static org.hamcrest.Matchers.instanceOf; |
@@ -842,166 +843,135 @@ public void testLinearRetrieverWithMinScoreValidation() { |
842 | 843 | StandardRetrieverBuilder retriever1 = new StandardRetrieverBuilder(new MatchAllQueryBuilder()); |
843 | 844 | float[] weights = new float[] { 1.0f }; |
844 | 845 | ScoreNormalizer[] normalizers = LinearRetrieverBuilder.getDefaultNormalizers(1); |
845 | | - |
846 | | - // Test negative minScore |
847 | 846 | LinearRetrieverBuilder builder = new LinearRetrieverBuilder( |
848 | 847 | List.of(new CompoundRetrieverBuilder.RetrieverSource(retriever1, null)), |
849 | 848 | 10, |
850 | 849 | weights, |
851 | 850 | normalizers |
852 | 851 | ); |
853 | | - |
854 | | - // This should throw an exception |
855 | | - IllegalArgumentException e = expectThrows( |
856 | | - IllegalArgumentException.class, |
857 | | - () -> builder.minScore(-0.1f) |
858 | | - ); |
| 852 | + |
| 853 | + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> builder.minScore(-0.1f)); |
859 | 854 | assertThat(e.getMessage(), equalTo("[min_score] must be greater than or equal to 0, was: -0.1")); |
860 | | - |
861 | | - // Test valid minScore |
862 | | - builder.minScore(0.1f); // This should not throw |
| 855 | + |
| 856 | + builder.minScore(0.1f); |
863 | 857 | assertThat(builder.minScore(), equalTo(0.1f)); |
864 | 858 | } |
865 | 859 |
|
866 | 860 | // public void testLinearRetrieverWithMinScoreScenarios() { |
867 | 861 | // final int rankWindowSize = 10; |
868 | 862 |
|
869 | | - // // Define scores for TestRetrieverBuilder (documents exist from setupIndex) |
870 | | - // TestRetrieverBuilder retrieverA = new TestRetrieverBuilder( |
871 | | - // Map.of("doc_1", 10.0f, "doc_2", 8.0f, "doc_3", 6.0f, "doc_4", 4.0f) |
872 | | - // ); |
873 | | - // TestRetrieverBuilder retrieverB = new TestRetrieverBuilder( |
874 | | - // Map.of("doc_1", 1.0f, "doc_2", 3.0f, "doc_3", 5.0f, "doc_4", 2.0f) |
875 | | - // ); |
| 863 | + // // Setup test data |
| 864 | + // indexDoc(INDEX, "doc_1", TEXT_FIELD, "term1", "views.last30d", 10, "views.all", 100); |
| 865 | + // indexDoc(INDEX, "doc_2", TEXT_FIELD, "term1 term2", "views.last30d", 20, "views.all", 200); |
| 866 | + // indexDoc(INDEX, "doc_3", TEXT_FIELD, "term1 term2 term3", "views.last30d", 30, "views.all", 300); |
| 867 | + // indexDoc(INDEX, "doc_4", TEXT_FIELD, "term4", "views.last30d", 40, "views.all", 400); |
| 868 | + // refresh(INDEX); |
876 | 869 |
|
877 | | - // // Combined scores (weights {1.0f, 1.0f}, no normalization initially): |
878 | | - // // doc_1: 10.0 + 1.0 = 11.0 |
879 | | - // // doc_2: 8.0 + 3.0 = 11.0 |
880 | | - // // doc_3: 6.0 + 5.0 = 11.0 |
881 | | - // // doc_4: 4.0 + 2.0 = 6.0 |
| 870 | + // // Create retrievers with different scoring |
| 871 | + // StandardRetrieverBuilder retrieverA = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term1").boost(10.0f)); |
| 872 | + // StandardRetrieverBuilder retrieverB = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term2").boost(1.0f)); |
882 | 873 |
|
883 | 874 | // float[] weights = new float[] { 1.0f, 1.0f }; |
884 | 875 | // ScoreNormalizer[] identityNormalizers = LinearRetrieverBuilder.getDefaultNormalizers(2); |
885 | 876 |
|
886 | | - // // Scenario 1: minScore is null (not specified) - all docs returned |
887 | | - // LinearRetrieverBuilder builderNullMinScore = new LinearRetrieverBuilder( |
| 877 | + // // Scenario 1: No min_score - all docs returned |
| 878 | + // LinearRetrieverBuilder builderNoMinScore = new LinearRetrieverBuilder( |
888 | 879 | // List.of( |
889 | 880 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null), |
890 | 881 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null) |
891 | 882 | // ), |
892 | 883 | // rankWindowSize, |
893 | 884 | // weights, |
894 | | - // identityNormalizers, |
895 | | - // null // Explicitly null |
| 885 | + // identityNormalizers |
896 | 886 | // ); |
897 | | - // SearchSourceBuilder sourceNullMinScore = new SearchSourceBuilder().retriever(builderNullMinScore).size(rankWindowSize); |
898 | | - // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceNullMinScore), resp -> { |
899 | | - // assertThat(resp.getHits().getTotalHits().value, equalTo(4L)); |
900 | | - // List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).sorted().toList(); |
901 | | - // assertThat(ids, equalTo(List.of("doc_1", "doc_2", "doc_3", "doc_4"))); |
| 887 | + |
| 888 | + // SearchSourceBuilder sourceNoMinScore = new SearchSourceBuilder().retriever(builderNoMinScore).size(rankWindowSize); |
| 889 | + |
| 890 | + // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceNoMinScore), resp -> { |
| 891 | + // assertThat(resp.getHits().getTotalHits().value(), equalTo(3L)); // doc_1, doc_2, doc_3 match |
| 892 | + // assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_3")); // term1(10) + term2(1) = 11 |
| 893 | + // assertThat(resp.getHits().getHits()[1].getId(), equalTo("doc_2")); // term1(10) + term2(1) = 11 |
| 894 | + // assertThat(resp.getHits().getHits()[2].getId(), equalTo("doc_1")); // term1(10) = 10 |
902 | 895 | // }); |
903 | 896 |
|
904 | | - // // Scenario 2: minScore = 0.0f - all docs returned (as all scores are > 0) |
| 897 | + // // Scenario 2: minScore = 0.0f - all matching docs returned (inclusive) |
905 | 898 | // LinearRetrieverBuilder builderZeroMinScore = new LinearRetrieverBuilder( |
906 | 899 | // List.of( |
907 | 900 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null), |
908 | 901 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null) |
909 | 902 | // ), |
910 | 903 | // rankWindowSize, |
911 | 904 | // weights, |
912 | | - // identityNormalizers, |
913 | | - // 0.0f |
914 | | - // ); |
| 905 | + // identityNormalizers |
| 906 | + // ).minScore(0.0f); |
| 907 | + |
915 | 908 | // SearchSourceBuilder sourceZeroMinScore = new SearchSourceBuilder().retriever(builderZeroMinScore).size(rankWindowSize); |
916 | | - // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceZeroMinScore), resp -> { |
917 | | - // assertThat(resp.getHits().getTotalHits().value, equalTo(4L)); |
918 | | - // }); |
919 | 909 |
|
920 | | - // // Scenario 3: Basic filtering - minScore = 10.0f |
921 | | - // // Expect: doc_1 (11.0), doc_2 (11.0), doc_3 (11.0). doc_4 (6.0) is filtered out. |
| 910 | + // ElasticsearchAssertions.assertResponse( |
| 911 | + // client().prepareSearch(INDEX).setSource(sourceZeroMinScore), |
| 912 | + // resp -> assertThat(resp.getHits().getTotalHits().value(), equalTo(3L)) |
| 913 | + // ); |
| 914 | + |
| 915 | + // // Scenario 3: Basic filtering - minScore = 10.5f |
922 | 916 | // LinearRetrieverBuilder builderFilterBasic = new LinearRetrieverBuilder( |
923 | 917 | // List.of( |
924 | 918 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null), |
925 | 919 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null) |
926 | 920 | // ), |
927 | 921 | // rankWindowSize, |
928 | 922 | // weights, |
929 | | - // identityNormalizers, |
930 | | - // 10.0f |
931 | | - // ); |
| 923 | + // identityNormalizers |
| 924 | + // ).minScore(10.5f); |
| 925 | + |
932 | 926 | // SearchSourceBuilder sourceFilterBasic = new SearchSourceBuilder().retriever(builderFilterBasic).size(rankWindowSize); |
933 | | - // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterBasic), resp -> { |
934 | | - // assertThat(resp.getHits().getTotalHits().value, equalTo(3L)); |
935 | | - // List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).sorted().toList(); |
936 | | - // assertThat(ids, equalTo(List.of("doc_1", "doc_2", "doc_3"))); |
937 | | - // }); |
938 | 927 |
|
939 | | - // // Scenario 4: Inclusive filtering - minScore = 6.0f |
940 | | - // // Expect: doc_1 (11.0), doc_2 (11.0), doc_3 (11.0), doc_4 (6.0). doc_4 is included. |
941 | | - // LinearRetrieverBuilder builderFilterInclusive = new LinearRetrieverBuilder( |
942 | | - // List.of( |
943 | | - // new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null), |
944 | | - // new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null) |
945 | | - // ), |
946 | | - // rankWindowSize, |
947 | | - // weights, |
948 | | - // identityNormalizers, |
949 | | - // 6.0f |
950 | | - // ); |
951 | | - // SearchSourceBuilder sourceFilterInclusive = new SearchSourceBuilder().retriever(builderFilterInclusive).size(rankWindowSize); |
952 | | - // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterInclusive), resp -> { |
953 | | - // assertThat(resp.getHits().getTotalHits().value, equalTo(4L)); |
954 | | - // for (var hit : resp.getHits().getHits()) { |
955 | | - // if (hit.getId().equals("doc_4")) assertThat((double)hit.getScore(), closeTo(6.0, 1e-5)); |
956 | | - // else assertThat((double)hit.getScore(), closeTo(11.0, 1e-5)); |
957 | | - // } |
| 928 | + // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterBasic), resp -> { |
| 929 | + // assertThat(resp.getHits().getTotalHits().value(), equalTo(2L)); // doc_2 and doc_3 have score 11.0 |
| 930 | + // List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).collect(Collectors.toList()); |
| 931 | + // assertThat(ids, containsInAnyOrder("doc_2", "doc_3")); |
958 | 932 | // }); |
959 | 933 |
|
960 | | - // // Scenario 5: Filter all documents - minScore = 12.0f |
| 934 | + // // Scenario 4: Filter all documents - minScore = 20.0f |
961 | 935 | // LinearRetrieverBuilder builderFilterAll = new LinearRetrieverBuilder( |
962 | 936 | // List.of( |
963 | 937 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null), |
964 | 938 | // new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null) |
965 | 939 | // ), |
966 | 940 | // rankWindowSize, |
967 | 941 | // weights, |
968 | | - // identityNormalizers, |
969 | | - // 12.0f |
970 | | - // ); |
| 942 | + // identityNormalizers |
| 943 | + // ).minScore(20.0f); |
| 944 | + |
971 | 945 | // SearchSourceBuilder sourceFilterAll = new SearchSourceBuilder().retriever(builderFilterAll).size(rankWindowSize); |
972 | | - // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceFilterAll), resp -> { |
973 | | - // assertThat(resp.getHits().getTotalHits().value, equalTo(0L)); |
974 | | - // }); |
975 | 946 |
|
976 | | - // // Scenario 6: Interaction with MinMax Normalization |
977 | | - // // Retriever A scores: doc_1 (10), doc_2 (8), doc_3 (6), doc_4 (4) -> Normalized: (10-4)/(10-4)=1, (8-4)/(10-4)=0.666, (6-4)/(10-4)=0.333, (4-4)/(10-4)=0 |
978 | | - // // Retriever B scores: doc_1 (1), doc_2 (3), doc_3 (5), doc_4 (2) -> Normalized: (1-1)/(5-1)=0, (3-1)/(5-1)=0.5, (5-1)/(5-1)=1, (2-1)/(5-1)=0.25 |
979 | | - // // Combined normalized scores (weights {1.0, 1.0}): |
980 | | - // // doc_1: 1.0 + 0.0 = 1.0 |
981 | | - // // doc_2: 0.666 + 0.5 = 1.166 |
982 | | - // // doc_3: 0.333 + 1.0 = 1.333 |
983 | | - // // doc_4: 0.0 + 0.25 = 0.25 |
| 947 | + // ElasticsearchAssertions.assertResponse( |
| 948 | + // client().prepareSearch(INDEX).setSource(sourceFilterAll), |
| 949 | + // resp -> assertThat(resp.getHits().getTotalHits().value(), equalTo(0L)) |
| 950 | + // ); |
| 951 | + |
| 952 | + // // Scenario 5: Test with MinMax normalization |
| 953 | + // StandardRetrieverBuilder retrieverC = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term1").boost(4.0f)); |
| 954 | + // StandardRetrieverBuilder retrieverD = new StandardRetrieverBuilder(QueryBuilders.termQuery(TEXT_FIELD, "term2").boost(1.0f)); |
| 955 | + |
984 | 956 | // ScoreNormalizer[] minMaxNormalizers = new ScoreNormalizer[] { MinMaxScoreNormalizer.INSTANCE, MinMaxScoreNormalizer.INSTANCE }; |
| 957 | + |
985 | 958 | // LinearRetrieverBuilder builderWithNorm = new LinearRetrieverBuilder( |
986 | 959 | // List.of( |
987 | | - // new CompoundRetrieverBuilder.RetrieverSource(retrieverA, null), |
988 | | - // new CompoundRetrieverBuilder.RetrieverSource(retrieverB, null) |
| 960 | + // new CompoundRetrieverBuilder.RetrieverSource(retrieverC, null), |
| 961 | + // new CompoundRetrieverBuilder.RetrieverSource(retrieverD, null) |
989 | 962 | // ), |
990 | 963 | // rankWindowSize, |
991 | 964 | // weights, |
992 | | - // minMaxNormalizers, |
993 | | - // 1.1f // minScore after normalization |
994 | | - // ); |
| 965 | + // minMaxNormalizers |
| 966 | + // ).minScore(1.1f); |
| 967 | + |
995 | 968 | // SearchSourceBuilder sourceWithNorm = new SearchSourceBuilder().retriever(builderWithNorm).size(rankWindowSize); |
| 969 | + |
996 | 970 | // ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(sourceWithNorm), resp -> { |
997 | | - // // Expect doc_2 (1.166), doc_3 (1.333). doc_1 (1.0) and doc_4 (0.25) are filtered out. |
998 | | - // assertThat(resp.getHits().getTotalHits().value, equalTo(2L)); |
999 | | - // List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).sorted().toList(); |
1000 | | - // assertThat(ids, equalTo(List.of("doc_2", "doc_3"))); |
1001 | | - // for (var hit : resp.getHits().getHits()) { |
1002 | | - // if (hit.getId().equals("doc_2")) assertThat((double)hit.getScore(), closeTo(1.166, 0.001)); |
1003 | | - // if (hit.getId().equals("doc_3")) assertThat((double)hit.getScore(), closeTo(1.333, 0.001)); |
1004 | | - // } |
 | 971 | + // // With MinMax normalization, we expect only doc_2 and doc_3 to have scores >= 1.1 (min_score is inclusive) |
| 972 | + // assertThat(resp.getHits().getTotalHits().value(), equalTo(2L)); |
| 973 | + // List<String> ids = Arrays.stream(resp.getHits().getHits()).map(h -> h.getId()).collect(Collectors.toList()); |
| 974 | + // assertThat(ids, containsInAnyOrder("doc_2", "doc_3")); |
1005 | 975 | // }); |
1006 | 976 | // } |
1007 | 977 | } |
0 commit comments