1717import org .apache .lucene .index .LeafReader ;
1818import org .apache .lucene .search .BooleanClause ;
1919import org .apache .lucene .search .BooleanQuery ;
20+ import org .apache .lucene .search .BoostQuery ;
2021import org .apache .lucene .search .IndexSearcher ;
2122import org .apache .lucene .search .Query ;
2223import org .apache .lucene .store .Directory ;
3940import org .elasticsearch .search .lookup .Source ;
4041import org .elasticsearch .search .vectors .SparseVectorQueryWrapper ;
4142import org .elasticsearch .test .index .IndexVersionUtils ;
43+ import org .elasticsearch .test .junit .annotations .TestLogging ;
4244import org .elasticsearch .xcontent .ToXContent ;
4345import org .elasticsearch .xcontent .XContentBuilder ;
4446import org .elasticsearch .xcontent .XContentParseException ;
4850import org .junit .AssumptionViolatedException ;
4951
5052import java .io .IOException ;
53+ import java .util .ArrayList ;
5154import java .util .Arrays ;
5255import java .util .Collection ;
5356import java .util .EnumSet ;
5659import java .util .Map ;
5760import java .util .Set ;
5861import java .util .TreeMap ;
62+ import java .util .stream .Collectors ;
63+ import java .util .stream .Stream ;
5964
6065import static org .elasticsearch .index .IndexVersions .SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT ;
6166import static org .elasticsearch .index .IndexVersions .UPGRADE_TO_LUCENE_10_0_0 ;
@@ -73,25 +78,27 @@ public class SparseVectorFieldMapperTests extends SyntheticVectorsMapperTestCase
7378 public static final float STRICT_TOKENS_WEIGHT_THRESHOLD = 0.5f ;
7479 public static final float STRICT_TOKENS_FREQ_RATIO_THRESHOLD = 1 ;
7580
76- @ Override
77- protected Object getSampleValueForDocument () {
78- // randomMap(1, 5, () -> Tuple.tuple(randomAlphaOfLengthBetween(5, 10), Float.valueOf(randomIntBetween(1, 127))))
79- Map <String , Float > map = new TreeMap <>();
80-
81- // High weight tokens - low freq (should survive strict pruning)
82- map .put ("rare1" , (float ) randomIntBetween (1 , 127 ));
83- map .put ("rare2" , (float ) randomIntBetween (1 , 127 ));
81+ private static final Map <String , Float > COMMON_TOKENS = Map .of (
82+ "common1_drop_default" , 0.1f ,
83+ "common2_drop_default" , 0.1f ,
84+ "common3_drop_default" , 0.1f
85+ );
8486
85- // Medium weight - medium freq (half of them should survive strict pruning)
86- map .put ("medium_freq" , (float ) randomIntBetween (1 , 127 ));
87- map .put ("medium_freq2" , (float ) randomIntBetween (1 , 127 ));
87+ private static final Map <String , Float > MEDIUM_TOKENS = Map .of (
88+ "medium1_keep_strict" , 0.5f ,
89+ "medium2_keep_default" , 0.25f
90+ );
8891
89- // Low weight tokens - high freq (pruned under default or strict pruning)
90- map . put ( "common1 " , ( float ) randomIntBetween ( 1 , 127 ));
91- map . put ( "common2 " , ( float ) randomIntBetween ( 1 , 127 ));
92- map . put ( "common3" , ( float ) randomIntBetween ( 1 , 127 ) );
92+ private static final Map < String , Float > RARE_TOKENS = Map . of (
93+ "rare1_keep_strict " , 0.9f ,
94+ "rare2_keep_strict " , 0.85f
95+ );
9396
94- return map ;
97+ @ Override
98+ protected Object getSampleValueForDocument () {
99+ return new TreeMap <>(
100+ randomMap (1 , 5 , () -> Tuple .tuple (randomAlphaOfLengthBetween (5 , 10 ), Float .valueOf (randomIntBetween (1 , 127 ))))
101+ );
95102 }
96103
97104 @ Override
@@ -722,30 +729,17 @@ private void withSearchExecutionContext(MapperService mapperService, CheckedCons
722729 try (Directory directory = newDirectory ()) {
723730 RandomIndexWriter iw = new RandomIndexWriter (random (), directory );
724731
725- Map <String , Float > commonTokens = new TreeMap <>();
726- commonTokens .put ("common1" , 0.1f );
727- commonTokens .put ("common2" , 0.1f );
728- commonTokens .put ("common3" , 0.1f );
729-
730- Map <String , Float > mediumTokens = new TreeMap <>();
731- mediumTokens .put ("medium1" , 0.5f );
732- mediumTokens .put ("medium2" , 0.25f );
733-
734- Map <String , Float > rareTokens = new TreeMap <>();
735- rareTokens .put ("rare1" , 0.9f );
736- rareTokens .put ("rare2" , 0.85f );
737-
738732 int commonDocs = 20 ;
739733 for (int i = 0 ; i < commonDocs ; i ++) {
740- iw .addDocument (mapper .parse (source (b -> b .field ("field" , commonTokens ))).rootDoc ());
734+ iw .addDocument (mapper .parse (source (b -> b .field ("field" , COMMON_TOKENS ))).rootDoc ());
741735 }
742736
743737 int mediumDocs = 5 ;
744738 for (int i = 0 ; i < mediumDocs ; i ++) {
745- iw .addDocument (mapper .parse (source (b -> b .field ("field" , mediumTokens ))).rootDoc ());
739+ iw .addDocument (mapper .parse (source (b -> b .field ("field" , MEDIUM_TOKENS ))).rootDoc ());
746740 }
747741
748- iw .addDocument (mapper .parse (source (b -> b .field ("field" , rareTokens ))).rootDoc ());
742+ iw .addDocument (mapper .parse (source (b -> b .field ("field" , RARE_TOKENS ))).rootDoc ());
749743
750744 // This will lower the averageTokenFreqRatio so that common tokens get pruned with default settings
751745 Map <String , Float > uniqueDoc = new TreeMap <>();
@@ -762,13 +756,13 @@ private void withSearchExecutionContext(MapperService mapperService, CheckedCons
762756 }
763757 }
764758
765- public void testTypeQueryFinalizationPruningScenarios () throws Exception {
759+ public void testPruningScenarios () throws Exception {
766760 for (int i = 0 ; i < 60 ; i ++) {
767- runTestTypeQueryFinalization (randomFrom (IndexPruningScenario .values ()), randomFrom (QueryPruningScenario .values ()));
761+ assertPruningScenario (randomFrom (IndexPruningScenario .values ()), randomFrom (QueryPruningScenario .values ()));
768762 }
769763 }
770764
771- public void testTypeQueryFinalizationDefaultsPreviousVersion () throws Exception {
765+ public void testPruningDefaultsPreIndexOptions () throws Exception {
772766 IndexVersion version = IndexVersionUtils .randomVersionBetween (
773767 random (),
774768 UPGRADE_TO_LUCENE_10_0_0 ,
@@ -791,7 +785,8 @@ public void testTypeQueryFinalizationDefaultsPreviousVersion() throws Exception
791785 queryPruneConfig .v2 ()
792786 );
793787 // query should _not_ be pruned by default on older index versions
794- assertQueryWasPruned (finalizedQuery , PruningScenario .NO_PRUNING );
788+ List <Query > expectedQueryClauses = getExpectedQueryClauses (ft , PruningScenario .NO_PRUNING , context );
789+ assertQueryContains (expectedQueryClauses , finalizedQuery );
795790 });
796791
797792 }
@@ -805,12 +800,14 @@ private XContentBuilder getIndexMapping(IndexPruningScenario pruningScenario) th
805800 };
806801 }
807802
808- private void assertQueryWasPruned (Query query , PruningScenario pruningScenario ) {
809- switch (pruningScenario ) {
810- case NO_PRUNING -> assertQueryHasClauseCount (query , QUERY_VECTORS .size ());
811- case DEFAULT_PRUNING -> assertQueryHasClauseCount (query , QUERY_VECTORS .size () - 3 ); // 3 common tokens pruned
812- case STRICT_PRUNING -> assertQueryHasClauseCount (query , QUERY_VECTORS .size () - 4 ); // 3 common and 1 medium tokens pruned
813- }
803+ private void assertQueryContains (List <Query > expectedClauses , Query query ) {
804+ SparseVectorQueryWrapper queryWrapper = (SparseVectorQueryWrapper ) query ;
805+ var termsQuery = queryWrapper .getTermsQuery ();
806+ assertNotNull (termsQuery );
807+ var booleanQuery = (BooleanQuery ) termsQuery ;
808+
809+ Collection <Query > shouldClauses = booleanQuery .getClauses (BooleanClause .Occur .SHOULD );
810+ assertThat (shouldClauses , Matchers .containsInAnyOrder (expectedClauses .toArray ()));
814811 }
815812
816813 private void assertQueryHasClauseCount (Query query , int clauseCount ) {
@@ -880,7 +877,22 @@ private Tuple<Boolean, TokenPruningConfig> getQueryPruneConfig(QueryPruningScena
880877 };
881878 }
882879
883- private void runTestTypeQueryFinalization (IndexPruningScenario indexPruningScenario , QueryPruningScenario queryPruningScenario )
880+ private List <Query > getExpectedQueryClauses (SparseVectorFieldMapper .SparseVectorFieldType ft , PruningScenario pruningScenario , SearchExecutionContext searchExecutionContext ) {
881+ List <WeightedToken > tokens = switch (pruningScenario ) {
882+ case NO_PRUNING -> QUERY_VECTORS ;
883+ case DEFAULT_PRUNING -> QUERY_VECTORS .stream ()
884+ .filter (t -> t .token ().startsWith ("rare" ) || t .token ().startsWith ("medium" )).toList ();
885+ case STRICT_PRUNING -> QUERY_VECTORS .stream ()
886+ .filter (t -> t .token ().endsWith ("keep_strict" )).toList ();
887+ };
888+
889+ return tokens .stream ().map (t -> {
890+ Query termQuery = ft .termQuery (t .token (), searchExecutionContext );
891+ return new BoostQuery (termQuery , t .weight ());
892+ }).collect (Collectors .toUnmodifiableList ());
893+ }
894+
895+ private void assertPruningScenario (IndexPruningScenario indexPruningScenario , QueryPruningScenario queryPruningScenario )
884896 throws IOException {
885897 logger .debug ("Running test with indexPruningScenario: {}, queryPruningScenario: {}" , indexPruningScenario , queryPruningScenario );
886898 IndexVersion indexVersion = IndexVersionUtils .randomVersionBetween (
@@ -896,14 +908,15 @@ private void runTestTypeQueryFinalization(IndexPruningScenario indexPruningScena
896908 SparseVectorFieldMapper .SparseVectorFieldType ft = (SparseVectorFieldMapper .SparseVectorFieldType ) mapperService .fieldType (
897909 "field"
898910 );
911+ List <Query > expectedQueryClauses = getExpectedQueryClauses (ft , effectivePruningScenario , context );
899912 Query finalizedQuery = ft .finalizeSparseVectorQuery (
900913 context ,
901914 "field" ,
902915 QUERY_VECTORS ,
903916 queryPruneConfig .v1 (),
904917 queryPruneConfig .v2 ()
905918 );
906- assertQueryWasPruned ( finalizedQuery , effectivePruningScenario );
919+ assertQueryContains ( expectedQueryClauses , finalizedQuery );
907920 });
908921 }
909922
@@ -917,17 +930,10 @@ private IndexVersion getIndexVersionForTest(boolean usePreviousIndex) {
917930 : IndexVersionUtils .randomVersionBetween (random (), SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT , IndexVersion .current ());
918931 }
919932
920- private static List <WeightedToken > QUERY_VECTORS = List .of (
921- new WeightedToken ("rare1" , 0.9f ),
922- new WeightedToken ("rare2" , 0.85f ),
923-
924- new WeightedToken ("medium1" , 0.5f ), // this will survive strict pruning, due to higher weight
925- new WeightedToken ("medium2" , 0.25f ),
926-
927- new WeightedToken ("common1" , 0.2f ),
928- new WeightedToken ("common2" , 0.15f ),
929- new WeightedToken ("common3" , 0.1f )
930- );
933+ private static final List <WeightedToken > QUERY_VECTORS = Stream .of (RARE_TOKENS , MEDIUM_TOKENS , COMMON_TOKENS )
934+ .flatMap (map -> map .entrySet ().stream ())
935+ .map (entry -> new WeightedToken (entry .getKey (), entry .getValue ()))
936+ .collect (Collectors .toList ());
931937
932938 /**
933939 * Handles float/double conversion when reading/writing with xcontent by converting all numbers to floats.
0 commit comments