|
25 | 25 | import java.util.Map; |
26 | 26 | import org.apache.lucene.index.VectorEncoding; |
27 | 27 | import org.apache.lucene.index.VectorSimilarityFunction; |
| 28 | +import org.apache.lucene.search.BooleanQuery; |
| 29 | +import org.apache.lucene.search.KnnByteVectorQuery; |
| 30 | +import org.apache.lucene.search.KnnFloatVectorQuery; |
| 31 | +import org.apache.lucene.search.PatienceKnnVectorQuery; |
| 32 | +import org.apache.lucene.search.Query; |
| 33 | +import org.apache.lucene.search.SeededKnnVectorQuery; |
| 34 | +import org.apache.lucene.search.knn.KnnSearchStrategy; |
28 | 35 | import org.apache.solr.client.solrj.request.JavaBinUpdateRequestCodec; |
29 | 36 | import org.apache.solr.client.solrj.request.UpdateRequest; |
30 | 37 | import org.apache.solr.common.SolrException; |
|
35 | 42 | import org.apache.solr.handler.loader.JavabinLoader; |
36 | 43 | import org.apache.solr.request.SolrQueryRequest; |
37 | 44 | import org.apache.solr.response.SolrQueryResponse; |
| 45 | +import org.apache.solr.search.neural.KnnQParser; |
38 | 46 | import org.apache.solr.update.CommitUpdateCommand; |
39 | 47 | import org.apache.solr.update.processor.UpdateRequestProcessor; |
40 | 48 | import org.apache.solr.update.processor.UpdateRequestProcessorChain; |
@@ -838,4 +846,283 @@ public void testIndexingViaJavaBin() throws Exception { |
838 | 846 | deleteCore(); |
839 | 847 | } |
840 | 848 | } |
| 849 | + |
| 850 | + @Test |
| 851 | + public void testFilteredSearchThreshold_floatNoThresholdInInput_shouldSetDefaultThreshold() |
| 852 | + throws Exception { |
| 853 | + try { |
| 854 | + Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD; |
| 855 | + |
| 856 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 857 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 858 | + SchemaField vectorField = schema.getField("vector"); |
| 859 | + assertNotNull(vectorField); |
| 860 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 861 | + KnnFloatVectorQuery vectorQuery = |
| 862 | + (KnnFloatVectorQuery) |
| 863 | + type.getKnnVectorQuery("vector", "[2, 1, 3, 4]", 3, null, null, null, null); |
| 864 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 865 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 866 | + |
| 867 | + assertEquals(expectedThreshold, threshold); |
| 868 | + } finally { |
| 869 | + deleteCore(); |
| 870 | + } |
| 871 | + } |
| 872 | + |
| 873 | + @Test |
| 874 | + public void testFilteredSearchThreshold_floatThresholdInInput_shouldSetCustomThreshold() |
| 875 | + throws Exception { |
| 876 | + try { |
| 877 | + Integer expectedThreshold = 30; |
| 878 | + |
| 879 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 880 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 881 | + SchemaField vectorField = schema.getField("vector"); |
| 882 | + assertNotNull(vectorField); |
| 883 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 884 | + KnnFloatVectorQuery vectorQuery = |
| 885 | + (KnnFloatVectorQuery) |
| 886 | + type.getKnnVectorQuery( |
| 887 | + "vector", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold); |
| 888 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 889 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 890 | + |
| 891 | + assertEquals(expectedThreshold, threshold); |
| 892 | + } finally { |
| 893 | + deleteCore(); |
| 894 | + } |
| 895 | + } |
| 896 | + |
| 897 | + @Test |
| 898 | + public void testFilteredSearchThreshold_seededFloatThresholdInInput_shouldSetCustomThreshold() |
| 899 | + throws Exception { |
| 900 | + try { |
| 901 | + Query seedQuery = new BooleanQuery.Builder().build(); |
| 902 | + Integer expectedThreshold = 30; |
| 903 | + |
| 904 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 905 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 906 | + SchemaField vectorField = schema.getField("vector"); |
| 907 | + assertNotNull(vectorField); |
| 908 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 909 | + SeededKnnVectorQuery vectorQuery = |
| 910 | + (SeededKnnVectorQuery) |
| 911 | + type.getKnnVectorQuery( |
| 912 | + "vector", "[2, 1, 3, 4]", 3, null, seedQuery, null, expectedThreshold); |
| 913 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 914 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 915 | + |
| 916 | + assertEquals(expectedThreshold, threshold); |
| 917 | + } finally { |
| 918 | + deleteCore(); |
| 919 | + } |
| 920 | + } |
| 921 | + |
| 922 | + @Test |
| 923 | + public void |
| 924 | + testFilteredSearchThreshold_earlyTerminationFloatThresholdInInput_shouldSetCustomThreshold() |
| 925 | + throws Exception { |
| 926 | + try { |
| 927 | + KnnQParser.EarlyTerminationParams earlyTermination = |
| 928 | + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); |
| 929 | + Integer expectedThreshold = 30; |
| 930 | + |
| 931 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 932 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 933 | + SchemaField vectorField = schema.getField("vector"); |
| 934 | + assertNotNull(vectorField); |
| 935 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 936 | + PatienceKnnVectorQuery vectorQuery = |
| 937 | + (PatienceKnnVectorQuery) |
| 938 | + type.getKnnVectorQuery( |
| 939 | + "vector", "[2, 1, 3, 4]", 3, null, null, earlyTermination, expectedThreshold); |
| 940 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 941 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 942 | + |
| 943 | + assertEquals(expectedThreshold, threshold); |
| 944 | + } finally { |
| 945 | + deleteCore(); |
| 946 | + } |
| 947 | + } |
| 948 | + |
| 949 | + @Test |
| 950 | + public void |
| 951 | + testFilteredSearchThreshold_seededAndEarlyTerminationFloatThresholdInInput_shouldSetCustomThreshold() |
| 952 | + throws Exception { |
| 953 | + try { |
| 954 | + Query seedQuery = new BooleanQuery.Builder().build(); |
| 955 | + KnnQParser.EarlyTerminationParams earlyTermination = |
| 956 | + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); |
| 957 | + Integer expectedThreshold = 30; |
| 958 | + |
| 959 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 960 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 961 | + SchemaField vectorField = schema.getField("vector"); |
| 962 | + assertNotNull(vectorField); |
| 963 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 964 | + PatienceKnnVectorQuery vectorQuery = |
| 965 | + (PatienceKnnVectorQuery) |
| 966 | + type.getKnnVectorQuery( |
| 967 | + "vector", |
| 968 | + "[2, 1, 3, 4]", |
| 969 | + 3, |
| 970 | + null, |
| 971 | + seedQuery, |
| 972 | + earlyTermination, |
| 973 | + expectedThreshold); |
| 974 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 975 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 976 | + |
| 977 | + assertEquals(expectedThreshold, threshold); |
| 978 | + } finally { |
| 979 | + deleteCore(); |
| 980 | + } |
| 981 | + } |
| 982 | + |
| 983 | + @Test |
| 984 | + public void testFilteredSearchThreshold_byteNoThresholdInInput_shouldSetDefaultThreshold() |
| 985 | + throws Exception { |
| 986 | + try { |
| 987 | + Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD; |
| 988 | + |
| 989 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 990 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 991 | + SchemaField vectorField = schema.getField("vector_byte_encoding"); |
| 992 | + assertNotNull(vectorField); |
| 993 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 994 | + KnnByteVectorQuery vectorQuery = |
| 995 | + (KnnByteVectorQuery) |
| 996 | + type.getKnnVectorQuery( |
| 997 | + "vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, null); |
| 998 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 999 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 1000 | + |
| 1001 | + assertEquals(expectedThreshold, threshold); |
| 1002 | + } finally { |
| 1003 | + deleteCore(); |
| 1004 | + } |
| 1005 | + } |
| 1006 | + |
| 1007 | + @Test |
| 1008 | + public void testFilteredSearchThreshold_byteThresholdInInput_shouldSetCustomThreshold() |
| 1009 | + throws Exception { |
| 1010 | + try { |
| 1011 | + Integer expectedThreshold = 30; |
| 1012 | + |
| 1013 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 1014 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 1015 | + SchemaField vectorField = schema.getField("vector_byte_encoding"); |
| 1016 | + assertNotNull(vectorField); |
| 1017 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 1018 | + KnnByteVectorQuery vectorQuery = |
| 1019 | + (KnnByteVectorQuery) |
| 1020 | + type.getKnnVectorQuery( |
| 1021 | + "vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold); |
| 1022 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 1023 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 1024 | + |
| 1025 | + assertEquals(expectedThreshold, threshold); |
| 1026 | + } finally { |
| 1027 | + deleteCore(); |
| 1028 | + } |
| 1029 | + } |
| 1030 | + |
| 1031 | + @Test |
| 1032 | + public void testFilteredSearchThreshold_seededByteThresholdInInput_shouldSetCustomThreshold() |
| 1033 | + throws Exception { |
| 1034 | + try { |
| 1035 | + Query seedQuery = new BooleanQuery.Builder().build(); |
| 1036 | + Integer expectedThreshold = 30; |
| 1037 | + |
| 1038 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 1039 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 1040 | + SchemaField vectorField = schema.getField("vector_byte_encoding"); |
| 1041 | + assertNotNull(vectorField); |
| 1042 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 1043 | + SeededKnnVectorQuery vectorQuery = |
| 1044 | + (SeededKnnVectorQuery) |
| 1045 | + type.getKnnVectorQuery( |
| 1046 | + "vector_byte_encoding", |
| 1047 | + "[2, 1, 3, 4]", |
| 1048 | + 3, |
| 1049 | + null, |
| 1050 | + seedQuery, |
| 1051 | + null, |
| 1052 | + expectedThreshold); |
| 1053 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 1054 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 1055 | + |
| 1056 | + assertEquals(expectedThreshold, threshold); |
| 1057 | + } finally { |
| 1058 | + deleteCore(); |
| 1059 | + } |
| 1060 | + } |
| 1061 | + |
| 1062 | + @Test |
| 1063 | + public void |
| 1064 | + testFilteredSearchThreshold_earlyTerminationByteThresholdInInput_shouldSetCustomThreshold() |
| 1065 | + throws Exception { |
| 1066 | + try { |
| 1067 | + KnnQParser.EarlyTerminationParams earlyTermination = |
| 1068 | + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); |
| 1069 | + Integer expectedThreshold = 30; |
| 1070 | + |
| 1071 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 1072 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 1073 | + SchemaField vectorField = schema.getField("vector_byte_encoding"); |
| 1074 | + assertNotNull(vectorField); |
| 1075 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 1076 | + PatienceKnnVectorQuery vectorQuery = |
| 1077 | + (PatienceKnnVectorQuery) |
| 1078 | + type.getKnnVectorQuery( |
| 1079 | + "vector_byte_encoding", |
| 1080 | + "[2, 1, 3, 4]", |
| 1081 | + 3, |
| 1082 | + null, |
| 1083 | + null, |
| 1084 | + earlyTermination, |
| 1085 | + expectedThreshold); |
| 1086 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 1087 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 1088 | + |
| 1089 | + assertEquals(expectedThreshold, threshold); |
| 1090 | + } finally { |
| 1091 | + deleteCore(); |
| 1092 | + } |
| 1093 | + } |
| 1094 | + |
| 1095 | + @Test |
| 1096 | + public void |
| 1097 | + testFilteredSearchThreshold_seededAndEarlyTerminationByteThresholdInInput_shouldSetCustomThreshold() |
| 1098 | + throws Exception { |
| 1099 | + try { |
| 1100 | + Query seedQuery = new BooleanQuery.Builder().build(); |
| 1101 | + KnnQParser.EarlyTerminationParams earlyTermination = |
| 1102 | + new KnnQParser.EarlyTerminationParams(true, 0.995, 7); |
| 1103 | + Integer expectedThreshold = 30; |
| 1104 | + |
| 1105 | + initCore("solrconfig-basic.xml", "schema-densevector.xml"); |
| 1106 | + IndexSchema schema = h.getCore().getLatestSchema(); |
| 1107 | + SchemaField vectorField = schema.getField("vector_byte_encoding"); |
| 1108 | + assertNotNull(vectorField); |
| 1109 | + DenseVectorField type = (DenseVectorField) vectorField.getType(); |
| 1110 | + PatienceKnnVectorQuery vectorQuery = |
| 1111 | + (PatienceKnnVectorQuery) |
| 1112 | + type.getKnnVectorQuery( |
| 1113 | + "vector_byte_encoding", |
| 1114 | + "[2, 1, 3, 4]", |
| 1115 | + 3, |
| 1116 | + null, |
| 1117 | + seedQuery, |
| 1118 | + earlyTermination, |
| 1119 | + expectedThreshold); |
| 1120 | + KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy(); |
| 1121 | + Integer threshold = strategy.filteredSearchThreshold(); |
| 1122 | + |
| 1123 | + assertEquals(expectedThreshold, threshold); |
| 1124 | + } finally { |
| 1125 | + deleteCore(); |
| 1126 | + } |
| 1127 | + } |
841 | 1128 | } |
0 commit comments