@@ -1628,10 +1628,6 @@ private void testFullTextFunctionWithPushableConjunction(FullTextFunctionTestCas
16281628 assertEquals (expected .toString (), esQuery .query ().toString ());
16291629 }
16301630
1631- public void testKnn () {
1632- testFullTextFunctionWithNonPushableDisjunction (new KnnFunctionTestCase ());
1633- }
1634-
16351631 private void testFullTextFunctionWithNonPushableDisjunction (FullTextFunctionTestCase testCase ) {
16361632 String query = String .format (Locale .ROOT , """
16371633 from test
@@ -1652,32 +1648,6 @@ private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTest
16521648 assertThat (fieldExtract .child (), instanceOf (EsQueryExec .class ));
16531649 }
16541650
1655- public void testKnnPrefilters () {
1656- String query = """
1657- from test
1658- | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1659- """ ;
1660- var plan = plannerOptimizer .plan (query , IS_SV_STATS , makeAnalyzer ("mapping-all-types.json" ));
1661-
1662- var limit = as (plan , LimitExec .class );
1663- var exchange = as (limit .child (), ExchangeExec .class );
1664- var project = as (exchange .child (), ProjectExec .class );
1665- var field = as (project .child (), FieldExtractExec .class );
1666- var queryExec = as (field .child (), EsQueryExec .class );
1667- QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery (
1668- query ,
1669- unscore (rangeQuery ("integer" ).gt (10 )),
1670- "integer" ,
1671- new Source (2 , 45 , "integer > 10" )
1672- );
1673- KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder ("dense_vector" , new float [] { 0 , 1 , 2 }, 10 , null , null , null )
1674- .addFilterQuery (expectedFilterQueryBuilder );
1675- var expectedQuery = boolQuery ()
1676- .must (expectedKnnQueryBuilder )
1677- .must (expectedFilterQueryBuilder );
1678- assertEquals (expectedQuery .toString (), queryExec .query ().toString ());
1679- }
1680-
16811651 private void testFullTextFunctionWithPushableDisjunction (FullTextFunctionTestCase testCase ) {
16821652 String query = String .format (Locale .ROOT , """
16831653 from test
@@ -1697,12 +1667,11 @@ private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCas
16971667 }
16981668
16991669 private FullTextFunctionTestCase randomFullTextFunctionTestCase () {
1700- return switch (randomIntBetween (0 , 4 )) {
1670+ return switch (randomIntBetween (0 , 3 )) {
17011671 case 0 -> new MatchFunctionTestCase ();
17021672 case 1 -> new MatchOperatorTestCase ();
17031673 case 2 -> new KqlFunctionTestCase ();
17041674 case 3 -> new QueryStringFunctionTestCase ();
1705- case 4 -> new KnnFunctionTestCase ();
17061675 default -> throw new IllegalStateException ("Unexpected value" );
17071676 };
17081677 }
@@ -1861,6 +1830,150 @@ public void testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) {
18611830 aggExec .forEachDown (EsQueryExec .class , esQueryExec -> { assertNull (esQueryExec .query ()); });
18621831 }
18631832
1833+ public void testKnnPrefilters () {
1834+ assumeTrue ("knn must be enabled" , EsqlCapabilities .Cap .KNN_FUNCTION_V3 .isEnabled ());
1835+
1836+ String query = """
1837+ from test
1838+ | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1839+ """ ;
1840+ var plan = plannerOptimizer .plan (query , IS_SV_STATS , makeAnalyzer ("mapping-all-types.json" ));
1841+
1842+ var limit = as (plan , LimitExec .class );
1843+ var exchange = as (limit .child (), ExchangeExec .class );
1844+ var project = as (exchange .child (), ProjectExec .class );
1845+ var field = as (project .child (), FieldExtractExec .class );
1846+ var queryExec = as (field .child (), EsQueryExec .class );
1847+ QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery (
1848+ query ,
1849+ unscore (rangeQuery ("integer" ).gt (10 )),
1850+ "integer" ,
1851+ new Source (2 , 45 , "integer > 10" )
1852+ );
1853+ KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder (
1854+ "dense_vector" ,
1855+ new float [] { 0 , 1 , 2 },
1856+ 10 ,
1857+ null ,
1858+ null ,
1859+ null
1860+ ).addFilterQuery (expectedFilterQueryBuilder );
1861+ var expectedQuery = boolQuery ().must (expectedKnnQueryBuilder ).must (expectedFilterQueryBuilder );
1862+ assertEquals (expectedQuery .toString (), queryExec .query ().toString ());
1863+ }
1864+
1865+ public void testPushDownConjunctionsToKnnPrefilter () {
1866+ assumeTrue ("knn must be enabled" , EsqlCapabilities .Cap .KNN_FUNCTION_V3 .isEnabled ());
1867+
1868+ String query = """
1869+ from test
1870+ | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1871+ """ ;
1872+ var plan = plannerOptimizer .plan (query , IS_SV_STATS , makeAnalyzer ("mapping-all-types.json" ));
1873+
1874+ var limit = as (plan , LimitExec .class );
1875+ var exchange = as (limit .child (), ExchangeExec .class );
1876+ var project = as (exchange .child (), ProjectExec .class );
1877+ var field = as (project .child (), FieldExtractExec .class );
1878+ var queryExec = as (field .child (), EsQueryExec .class );
1879+
1880+ // The filter condition should be pushed down to both the KNN query and the main query
1881+ QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery (
1882+ query ,
1883+ unscore (rangeQuery ("integer" ).gt (10 )),
1884+ "integer" ,
1885+ new Source (2 , 45 , "integer > 10" )
1886+ );
1887+
1888+ KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder (
1889+ "dense_vector" ,
1890+ new float [] { 0 , 1 , 2 },
1891+ 10 ,
1892+ null ,
1893+ null ,
1894+ null
1895+ ).addFilterQuery (expectedFilterQueryBuilder );
1896+
1897+ var expectedQuery = boolQuery ().must (expectedKnnQueryBuilder ).must (expectedFilterQueryBuilder );
1898+
1899+ assertEquals (expectedQuery .toString (), queryExec .query ().toString ());
1900+ }
1901+
1902+ public void testNotPushDownDisjunctionsToKnnPrefilter () {
1903+ assumeTrue ("knn must be enabled" , EsqlCapabilities .Cap .KNN_FUNCTION_V3 .isEnabled ());
1904+
1905+ String query = """
1906+ from test
1907+ | where knn(dense_vector, [0, 1, 2], 10) or integer > 10
1908+ """ ;
1909+ var plan = plannerOptimizer .plan (query , IS_SV_STATS , makeAnalyzer ("mapping-all-types.json" ));
1910+
1911+ var limit = as (plan , LimitExec .class );
1912+ var exchange = as (limit .child (), ExchangeExec .class );
1913+ var project = as (exchange .child (), ProjectExec .class );
1914+ var field = as (project .child (), FieldExtractExec .class );
1915+ var queryExec = as (field .child (), EsQueryExec .class );
1916+
1917+ // The disjunction should not be pushed down to the KNN query
1918+ KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder ("dense_vector" , new float [] { 0 , 1 , 2 }, 10 , null , null , null );
1919+ QueryBuilder rangeQueryBuilder = wrapWithSingleQuery (
1920+ query ,
1921+ unscore (rangeQuery ("integer" ).gt (10 )),
1922+ "integer" ,
1923+ new Source (2 , 44 , "integer > 10" )
1924+ );
1925+
1926+ var expectedQuery = boolQuery ().should (knnQueryBuilder ).should (rangeQueryBuilder );
1927+
1928+ assertEquals (expectedQuery .toString (), queryExec .query ().toString ());
1929+ }
1930+
1931+ public void testMultipleKnnQueriesInPrefilters () {
1932+ assumeTrue ("knn must be enabled" , EsqlCapabilities .Cap .KNN_FUNCTION_V3 .isEnabled ());
1933+
1934+ String query = """
1935+ from test
1936+ | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
1937+ """ ;
1938+ var plan = plannerOptimizer .plan (query , IS_SV_STATS , makeAnalyzer ("mapping-all-types.json" ));
1939+
1940+ var limit = as (plan , LimitExec .class );
1941+ var exchange = as (limit .child (), ExchangeExec .class );
1942+ var project = as (exchange .child (), ProjectExec .class );
1943+ var field = as (project .child (), FieldExtractExec .class );
1944+ var queryExec = as (field .child (), EsQueryExec .class );
1945+
1946+ KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder ("dense_vector" , new float [] { 0 , 1 , 2 }, 10 , null , null , null );
1947+ // Integer range query (right side of first OR)
1948+ QueryBuilder integerRangeQuery = wrapWithSingleQuery (
1949+ query ,
1950+ unscore (rangeQuery ("integer" ).gt (10 )),
1951+ "integer" ,
1952+ new Source (2 , 45 , "integer > 10" )
1953+ );
1954+
1955+ // Second KNN query (right side of second OR)
1956+ KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder ("dense_vector" , new float [] { 4 , 5 , 6 }, 10 , null , null , null );
1957+
1958+ // Keyword term query (left side of second OR)
1959+ QueryBuilder keywordQuery = wrapWithSingleQuery (
1960+ query ,
1961+ unscore (termQuery ("keyword" , "test" )),
1962+ "keyword" ,
1963+ new Source (2 , 87 , "keyword == \" test\" " )
1964+ );
1965+
1966+ // First OR (knn1 OR integer > 10)
1967+ var firstOr = boolQuery ().should (firstKnnQuery ).should (integerRangeQuery );
1968+ // Second OR (keyword == "test" OR knn2)
1969+ var secondOr = boolQuery ().should (keywordQuery ).should (secondKnnQuery .addFilterQuery (firstOr ));
1970+ firstKnnQuery .addFilterQuery (secondOr );
1971+
1972+ // Top-level AND combining both ORs
1973+ var expectedQuery = boolQuery ().must (firstOr ).must (secondOr );
1974+ assertEquals (expectedQuery .toString (), queryExec .query ().toString ());
1975+ }
1976+
18641977 public void testParallelizeTimeSeriesPlan () {
18651978 assumeTrue ("requires snapshot builds" , Build .current ().isSnapshot ());
18661979 var query = "TS k8s | STATS max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1h)" ;
0 commit comments