diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/subquery/uncorrelated/IoTDBUncorrelatedQuantifiedComparisonIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/subquery/uncorrelated/IoTDBUncorrelatedQuantifiedComparisonIT.java new file mode 100644 index 000000000000..11eb1feab901 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/subquery/uncorrelated/IoTDBUncorrelatedQuantifiedComparisonIT.java @@ -0,0 +1,674 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.relational.it.query.recent.subquery.uncorrelated; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.TableClusterIT; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData; +import static org.apache.iotdb.db.it.utils.TestUtils.tableAssertTestFail; +import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest; +import static org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.CREATE_SQLS; +import static org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.DATABASE_NAME; +import static org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.NUMERIC_MEASUREMENTS; + +@RunWith(IoTDBTestRunner.class) +@Category({TableLocalStandaloneIT.class, TableClusterIT.class}) +public class IoTDBUncorrelatedQuantifiedComparisonIT { + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setSortBufferSize(128 * 1024); + EnvFactory.getEnv().getConfig().getCommonConfig().setMaxTsBlockSizeInByte(4 * 1024); + EnvFactory.getEnv().initClusterEnvironment(); + prepareTableData(CREATE_SQLS); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testAnyAndSomeComparisonInWhereClauseWithoutNull() { + String sql; + String[] expectedHeader; + String[] retArray; + + // Test case: where s > any (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > any (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s > some (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > some (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s >= any (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= any (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s >= some (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= some (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s < any (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < any (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where < some (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < some (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s <= any (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s <= any (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s <= some (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s <= some (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s = any (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = any (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s = some (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = some (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s != any (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s != any (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s != some (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s != some (SELECT %s FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + } + + @Test + public void testAllComparisonInWhereClauseWithoutNull() { + String sql; + String[] expectedHeader; + String[] retArray; + + // Test case: where s > all (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > all (SELECT %s FROM table3 WHERE device_id = 'd01')"; + retArray = new String[] {"50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s >= all (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= all (SELECT %s FROM table3 WHERE device_id = 'd01')"; + retArray = new String[] {"40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s < all (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < all (SELECT %s FROM table3 WHERE device_id = 'd01')"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s <= all (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s <= all (SELECT %s FROM table3 WHERE device_id = 'd01')"; + retArray = new String[] {"30,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s = all (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = all (SELECT %s FROM table3 WHERE device_id = 'd01')"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s != all (subquery), s does not contain null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s != all (SELECT %s FROM table3 WHERE device_id = 'd01')"; + retArray = new String[] {"50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + } + + @Test + public void testAnyAndSomeComparisonInWhereClauseWithNull() { + String sql; + String[] expectedHeader; + String[] retArray; + + // Test case: where s1 > any (subquery), s1 in table3 contains null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > any (SELECT s1 FROM table3)"; + retArray = new String[] {"40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 > some (subquery), s1 in table3 contains null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > some (SELECT s1 FROM table3)"; + retArray = new String[] {"40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 >= any (subquery), s1 in table3 contains null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= any (SELECT s1 FROM table3)"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 >= some (subquery), s1 in table3 contains null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= some (SELECT s1 FROM table3)"; + retArray = new String[] {"30,", "40,", "50,", "60,", "70,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 < any (subquery), s1 in table3 contains null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < any (SELECT s1 FROM table3)"; + retArray = new String[] {"30,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 < some (subquery), s1 in table3 contains null value + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < some (SELECT s1 FROM table3)"; + retArray = new String[] {"30,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + } + + @Test + public void testAllComparisonInWhereClauseWithNull() { + String sql; + String[] expectedHeader; + String[] retArray; + + // Test case: where s1 > all (subquery), s1 in table3 contains null value. + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s > all (SELECT s1 FROM table3)"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 >= all (subquery), s1 in table3 contains null value. + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s >= all (SELECT s1 FROM table3)"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 < all (subquery), s1 in table3 contains null value. + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s < all (SELECT s1 FROM table3)"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 <= all (subquery), s1 in table3 contains null value. + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s <= all (SELECT s1 FROM table3)"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 = all (subquery), s1 in table3 contains null value. + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s = all (SELECT s1 FROM table3)"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: where s1 != all (subquery), s1 in table3 contains null value. + sql = + "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01' and cast(%s as INT32) != all (SELECT s1 FROM table3)"; + retArray = new String[] {}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + } + + @Test + public void testQuantifiedComparisonInWhereWithExpression() { + String sql; + String[] expectedHeader; + String[] retArray; + + sql = + "SELECT cast(%s + 10 AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s + 10 > any (SELECT %s + 10 FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"50,", "60,", "70,", "80,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + sql = + "SELECT cast(%s + 10 AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s + 10 > some (SELECT %s + 10 FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"50,", "60,", "70,", "80,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + sql = + "SELECT cast(%s + 10 AS INT32) as %s FROM table1 WHERE device_id = 'd01' and %s + 10 >= all (SELECT %s + 10 FROM table1 WHERE device_id = 'd01')"; + retArray = new String[] {"80,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + expectedHeader = new String[] {measurement}; + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + } + + @Test + public void testQuantifiedComparisonInHavingClause() { + String sql; + String[] expectedHeader; + String[] retArray; + + // Test case: having s >= any(subquery) + sql = + "SELECT device_id, count(*) from table1 group by device_id having count(*) + 25 >= any(SELECT cast(s1 as INT64) from table3 where device_id = 'd01')"; + expectedHeader = new String[] {"device_id", "_col1"}; + retArray = + new String[] { + "d01,5,", "d03,5,", "d05,5,", "d07,5,", "d09,5,", "d11,5,", "d13,5,", "d15,5," + }; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: having s >= some(subquery) + sql = + "SELECT device_id, count(*) from table1 group by device_id having count(*) + 25 >= some(SELECT cast(s1 as INT64) from table3 where device_id = 'd01')"; + expectedHeader = new String[] {"device_id", "_col1"}; + retArray = + new String[] { + "d01,5,", "d03,5,", "d05,5,", "d07,5,", "d09,5,", "d11,5,", "d13,5,", "d15,5," + }; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + + // Test case: having s >= all(subquery) + sql = + "SELECT device_id, count(*) from table1 group by device_id having count(*) + 35 >= all(SELECT cast(s1 as INT64) from table3 where device_id = 'd01')"; + expectedHeader = new String[] {"device_id", "_col1"}; + retArray = + new String[] { + "d01,5,", "d03,5,", "d05,5,", "d07,5,", "d09,5,", "d11,5,", "d13,5,", "d15,5," + }; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement, measurement, measurement), + expectedHeader, + retArray, + DATABASE_NAME); + } + } + + public void testQuantifiedComparisonInSelectClause() { + String sql; + String[] expectedHeader; + String[] retArray; + + // Test case: select s > any(subquery) + sql = + "SELECT %s > any(SELECT (%s) from table3 WHERE device_id = 'd01') from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"false,", "true,", "true,", "true,", "true,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s > some(subquery) + sql = + "SELECT %s > some(SELECT (%s) from table3 WHERE device_id = 'd01') from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"false,", "true,", "true,", "true,", "true,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s > all(subquery) + sql = + "SELECT %s > all(SELECT (%s) from table3 WHERE device_id = 'd01') from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"false,", "false,", "false,", "false,", "false,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s < any(subquery), subquery contains null value + sql = "SELECT %s < any(SELECT (%s) from table3) from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"null,", "null,", "null,", "null,", "null,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s < some(subquery), subquery contains null value + sql = "SELECT %s < some(SELECT (%s) from table3) from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"null,", "null,", "null,", "null,", "null,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s <= any(subquery), subquery contains null value + sql = "SELECT %s <= any(SELECT (%s) from table3) from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"true,", "null,", "null,", "null,", "null,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s <= some(subquery), subquery contains null value + sql = "SELECT %s <= some(SELECT (%s) from table3) from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"true,", "null,", "null,", "null,", "null,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s != all(subquery), subquery contains null value + sql = "SELECT %s != all(SELECT (%s) from table3) from table1 where device_id = 'd01'"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"false,", "false,", "null,", "null,", "null,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + + // Test case: select s != all(subquery), subquery result contains null value and s not in + // non-null + // value result set + sql = + "SELECT %s != all(SELECT (%s) from table3 where device_id = 'd_null') from table1 where device_id = 'd02' and %s != 30"; + expectedHeader = new String[] {"_col0"}; + retArray = new String[] {"null,", "null,"}; + for (String measurement : NUMERIC_MEASUREMENTS) { + tableResultSetEqualTest( + String.format(sql, measurement, measurement), expectedHeader, retArray, DATABASE_NAME); + } + } + + @Test + public void testQuantifiedComparisonLegalityCheck() { + // Legality check: only support any/some/all quantifier + tableAssertTestFail( + "select s1 from table1 where s1 > any_value (select s1 from table3)", + "mismatched input", + DATABASE_NAME); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java index 403627ed4617..a9b4122375f5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java @@ -254,6 +254,8 @@ public static TableAccumulator createBuiltinAccumulator( switch (aggregationType) { case COUNT: return new CountAccumulator(); + case COUNT_ALL: + return new CountAllAccumulator(); case COUNT_IF: return new CountIfAccumulator(); case AVG: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAllAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAllAccumulator.java new file mode 100644 index 000000000000..27867e387305 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAllAccumulator.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.RamUsageEstimator; + +import static com.google.common.base.Preconditions.checkArgument; + +public class CountAllAccumulator implements TableAccumulator { + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(CountAllAccumulator.class); + private long countState = 0; + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new CountAllAccumulator(); + } + + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + checkArgument(arguments.length == 1, "argument of CountAll should be one column"); + int count = mask.getSelectedPositionCount(); + countState += count; + } + + @Override + public void removeInput(Column[] arguments) { + checkArgument(arguments.length == 1, "argument of Count should be one column"); + int count = arguments[0].getPositionCount(); + countState -= count; + } + + @Override + public void addIntermediate(Column argument) { + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + countState += argument.getLong(i); + } + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + columnBuilder.writeLong(countState); + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + columnBuilder.writeLong(countState); + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException("CountAllAccumulator does not support statistics."); + } + + @Override + public void reset() { + countState = 0; + } + + @Override + public boolean removable() { + return true; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java index e2acc16987a4..55de53919385 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java @@ -493,7 +493,7 @@ protected ColumnTransformer visitGenericLiteral(GenericLiteral node, Context con return res; } - // currently, we only support Date and Timestamp + // currently, we only support Date/Timestamp/INT64 // for Date, GenericLiteral.value is an int value // for Timestamp, GenericLiteral.value is a long value private static ConstantColumnTransformer getColumnTransformerForGenericLiteral( @@ -506,6 +506,10 @@ private static ConstantColumnTransformer getColumnTransformerForGenericLiteral( return new ConstantColumnTransformer( TimestampType.TIMESTAMP, new LongColumn(1, Optional.empty(), new long[] {Long.parseLong(literal.getValue())})); + } else if (INT64.getTypeEnum().name().equals(literal.getType())) { + return new ConstantColumnTransformer( + INT64, + new LongColumn(1, Optional.empty(), new long[] {Long.parseLong(literal.getValue())})); } else { throw new SemanticException("Unsupported type in GenericLiteral: " + literal.getType()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java index ffd248cc845e..e94136cf74c6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java @@ -1533,9 +1533,11 @@ public Operator visitSemiJoin(SemiJoinNode node, LocalExecutionPlanContext conte Type sourceJoinKeyType = context.getTypeProvider().getTableModelType(node.getSourceJoinSymbol()); + checkIfJoinKeyTypeMatches( sourceJoinKeyType, context.getTypeProvider().getTableModelType(node.getFilteringSourceJoinSymbol())); + OperatorContext operatorContext = context .getDriverContext() diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index cfac519f769b..607b1c371355 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -628,6 +628,7 @@ && isIntegerNumber(argumentTypes.get(2)))) { // get return type switch (functionName.toLowerCase(Locale.ENGLISH)) { case SqlConstant.COUNT: + case SqlConstant.COUNT_ALL: case SqlConstant.COUNT_IF: return INT64; case SqlConstant.FIRST_AGGREGATION: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java index 1a19f40132f1..1ecf2dad77dc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java @@ -340,7 +340,7 @@ protected Type visitLongLiteral(LongLiteral node, Context context) { && node.getParsedValue() <= Integer.MAX_VALUE) { return setExpressionType(node, INT32); } - + // keep the original type return setExpressionType(node, INT64); } @@ -361,6 +361,8 @@ protected Type visitGenericLiteral(GenericLiteral node, Context context) { type = DateType.DATE; } else if (TimestampType.TIMESTAMP.getTypeEnum().name().equals(node.getType())) { type = TimestampType.TIMESTAMP; + } else if (INT64.getTypeEnum().name().equals(node.getType())) { + type = INT64; } else { throw new SemanticException("Unsupported type in GenericLiteral: " + node.getType()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SimplePlanRewriter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SimplePlanRewriter.java new file mode 100644 index 000000000000..d2a8ec781a19 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SimplePlanRewriter.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor; + +import com.google.common.collect.ImmutableList; + +import static com.google.common.base.Verify.verifyNotNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.ChildReplacer.replaceChildren; + +public abstract class SimplePlanRewriter + extends PlanVisitor> { + public static PlanNode rewriteWith(SimplePlanRewriter rewriter, PlanNode node) { + return node.accept(rewriter, new RewriteContext<>(rewriter, null)); + } + + public static PlanNode rewriteWith(SimplePlanRewriter rewriter, PlanNode node, C context) { + return node.accept(rewriter, new RewriteContext<>(rewriter, context)); + } + + @Override + public PlanNode visitPlan(PlanNode node, RewriteContext context) { + return context.defaultRewrite(node, context.get()); + } + + public static class RewriteContext { + private final C userContext; + private final SimplePlanRewriter nodeRewriter; + + private RewriteContext(SimplePlanRewriter nodeRewriter, C userContext) { + this.nodeRewriter = nodeRewriter; + this.userContext = userContext; + } + + public C get() { + return userContext; + } + + /** + * Invoke the rewrite logic recursively on children of the given node and swap it out with an + * identical copy with the rewritten children + */ + public PlanNode defaultRewrite(PlanNode node) { + return defaultRewrite(node, null); + } + + /** + * Invoke the rewrite logic recursively on children of the given node and swap it out with an + * identical copy with the rewritten children + */ + public PlanNode defaultRewrite(PlanNode node, C context) { + ImmutableList.Builder children = + ImmutableList.builderWithExpectedSize(node.getChildren().size()); + node.getChildren().forEach(source -> children.add(rewrite(source, context))); + return replaceChildren(node, children.build()); + } + + /** This method is meant for invoking the rewrite logic on children while processing a node */ + public PlanNode rewrite(PlanNode node, C userContext) { + PlanNode result = node.accept(nodeRewriter, new RewriteContext<>(nodeRewriter, userContext)); + return verifyNotNull(result, "nodeRewriter returned null for %s", node.getClass().getName()); + } + + /** This method is meant for invoking the rewrite logic on children while processing a node */ + public PlanNode rewrite(PlanNode node) { + return rewrite(node, null); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java index 62bdd1bf1b9d..b3766383c180 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java @@ -208,6 +208,7 @@ public LogicalOptimizeFactory(PlannerContext plannerContext) { new UnaliasSymbolReferences(plannerContext.getMetadata()), columnPruningOptimizer, inlineProjectionLimitFiltersOptimizer, + new TransformQuantifiedComparisonApplyToCorrelatedJoin(metadata), new IterativeOptimizer( plannerContext, ruleStats, @@ -215,6 +216,13 @@ public LogicalOptimizeFactory(PlannerContext plannerContext) { new RemoveRedundantEnforceSingleRowNode(), new RemoveUnreferencedScalarSubqueries(), new TransformUncorrelatedSubqueryToJoin(), new TransformUncorrelatedInPredicateSubqueryToSemiJoin())), + new IterativeOptimizer( + plannerContext, + ruleStats, + ImmutableSet.of( + new InlineProjections(plannerContext), new RemoveRedundantIdentityProjections() + /*new TransformCorrelatedSingleRowSubqueryToProject(), + new RemoveAggregationInSemiJoin())*/ )), new CheckSubqueryNodesAreRewritten(), new IterativeOptimizer( plannerContext, ruleStats, ImmutableSet.of(new PruneDistinctAggregation())), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java index 7cf45ff3cbd3..f039c09a701d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java @@ -212,7 +212,7 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) { Expression predicate = combineConjuncts(node.getPredicate(), context.inheritedPredicate); // when exist diff function, predicate can not be pushed down into DeviceTableScanNode - if (containsDiffFunction(predicate)) { + if (containsDiffFunction(predicate) || canNotPushDownBelowProjectNode(node, predicate)) { node.setChild(node.getChild().accept(this, new RewriteContext())); node.setPredicate(predicate); context.inheritedPredicate = TRUE_LITERAL; @@ -234,6 +234,35 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) { return node; } + private boolean canNotPushDownBelowProjectNode(FilterNode node, Expression predicate) { + PlanNode child = node.getChild(); + if (child instanceof ProjectNode) { + // if the inheritedPredicate is not in the output of the child of ProjectNode, we can not + // push it down for now. + // (predicate will be computed by the ProjectNode, Trino will rewrite the predicate in + // visitProject, but we have not implemented this for now.) + Set outputSymbolsOfProjectChild = + new HashSet<>(((ProjectNode) child).getChild().getOutputSymbols()); + return missingTermsInOutputSymbols(predicate, outputSymbolsOfProjectChild); + } + return false; + } + + private boolean missingTermsInOutputSymbols(Expression expression, Set outputSymbols) { + if (expression instanceof SymbolReference) { + return !outputSymbols.contains(Symbol.from(expression)); + } + if (!expression.getChildren().isEmpty()) { + for (Node node : expression.getChildren()) { + if (missingTermsInOutputSymbols((Expression) node, outputSymbols)) { + return true; + } + } + } + + return false; + } + // private boolean areExpressionsEquivalent( // Expression leftExpression, Expression rightExpression) { // return false; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java new file mode 100644 index 000000000000..00fecc1a69a4 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations; + +import org.apache.iotdb.db.queryengine.common.QueryId; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature; +import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId; +import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind; +import org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability; +import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; +import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; +import org.apache.iotdb.db.queryengine.plan.relational.planner.SimplePlanRewriter; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator; +import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.GenericLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullLiteral; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SearchedCaseExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.tsfile.read.common.type.LongType; +import org.apache.tsfile.read.common.type.Type; + +import java.util.EnumSet; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.SimplePlanRewriter.rewriteWith; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.combineConjuncts; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.globalAggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleAggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode.Quantifier.ALL; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.FALSE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.EQUAL; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.GREATER_THAN; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.LESS_THAN; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.NOT_EQUAL; +import static org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator.toSqlType; +import static org.apache.tsfile.read.common.type.BooleanType.BOOLEAN; + +public class TransformQuantifiedComparisonApplyToCorrelatedJoin implements PlanOptimizer { + private final Metadata metadata; + + public TransformQuantifiedComparisonApplyToCorrelatedJoin(Metadata metadata) { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public PlanNode optimize(PlanNode plan, Context context) { + return rewriteWith( + new Rewriter(context.idAllocator(), context.getSymbolAllocator(), metadata), plan, null); + } + + private static class Rewriter extends SimplePlanRewriter { + private final QueryId idAllocator; + private final SymbolAllocator symbolAllocator; + private final Metadata metadata; + + public Rewriter(QueryId idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public PlanNode visitApply(ApplyNode node, RewriteContext context) { + if (node.getSubqueryAssignments().size() != 1) { + return context.defaultRewrite(node); + } + + ApplyNode.SetExpression expression = getOnlyElement(node.getSubqueryAssignments().values()); + if (expression instanceof ApplyNode.QuantifiedComparison) { + return rewriteQuantifiedApplyNode( + node, (ApplyNode.QuantifiedComparison) expression, context); + } + + return context.defaultRewrite(node); + } + + private PlanNode rewriteQuantifiedApplyNode( + ApplyNode node, + ApplyNode.QuantifiedComparison quantifiedComparison, + RewriteContext context) { + PlanNode subqueryPlan = context.rewrite(node.getSubquery()); + + Symbol outputColumn = getOnlyElement(subqueryPlan.getOutputSymbols()); + Type outputColumnType = symbolAllocator.getTypes().getTableModelType(outputColumn); + checkState(outputColumnType.isOrderable(), "Subquery result type must be orderable"); + + Symbol minValue = symbolAllocator.newSymbol("min", outputColumnType); + Symbol maxValue = symbolAllocator.newSymbol("max", outputColumnType); + Symbol countAllValue = symbolAllocator.newSymbol("count_all", LongType.getInstance()); + Symbol countNonNullValue = + symbolAllocator.newSymbol("count_non_null", LongType.getInstance()); + + List outputColumnReferences = ImmutableList.of(outputColumn.toSymbolReference()); + + subqueryPlan = + singleAggregation( + idAllocator.genPlanNodeId(), + subqueryPlan, + ImmutableMap.of( + minValue, + new AggregationNode.Aggregation( + getResolvedBuiltInAggregateFunction( + "min", ImmutableList.of(outputColumnType)), + outputColumnReferences, + false, + Optional.empty(), + Optional.empty(), + Optional.empty()), + maxValue, + new AggregationNode.Aggregation( + getResolvedBuiltInAggregateFunction( + "max", ImmutableList.of(outputColumnType)), + outputColumnReferences, + false, + Optional.empty(), + Optional.empty(), + Optional.empty()), + countAllValue, + new AggregationNode.Aggregation( + getResolvedBuiltInAggregateFunction( + "count_all", ImmutableList.of(outputColumnType)), + outputColumnReferences, + false, + Optional.empty(), + Optional.empty(), + Optional.empty()), + countNonNullValue, + new AggregationNode.Aggregation( + getResolvedBuiltInAggregateFunction( + "count", ImmutableList.of(outputColumnType)), + outputColumnReferences, + false, + Optional.empty(), + Optional.empty(), + Optional.empty())), + globalAggregation()); + + PlanNode join = + new CorrelatedJoinNode( + node.getPlanNodeId(), + context.rewrite(node.getInput()), + subqueryPlan, + node.getCorrelation(), + JoinNode.JoinType.INNER, + TRUE_LITERAL, + node.getOriginSubquery()); + + Expression valueComparedToSubquery = + rewriteUsingBounds( + quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue); + + Symbol quantifiedComparisonSymbol = getOnlyElement(node.getSubqueryAssignments().keySet()); + + return projectExpressions( + join, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery)); + } + + private ResolvedFunction getResolvedBuiltInAggregateFunction( + String functionName, List argumentTypes) { + // The same as the code in ExpressionAnalyzer + Type type = metadata.getFunctionReturnType(functionName, argumentTypes); + return new ResolvedFunction( + new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, argumentTypes), + new FunctionId("noop"), + FunctionKind.AGGREGATE, + true, + FunctionNullability.getAggregationFunctionNullability(argumentTypes.size())); + } + + public Expression rewriteUsingBounds( + ApplyNode.QuantifiedComparison quantifiedComparison, + Symbol minValue, + Symbol maxValue, + Symbol countAllValue, + Symbol countNonNullValue) { + BooleanLiteral emptySetResult; + Function, Expression> quantifier; + if (quantifiedComparison.getQuantifier() == ALL) { + emptySetResult = TRUE_LITERAL; + quantifier = IrUtils::combineConjuncts; + } else { + emptySetResult = FALSE_LITERAL; + quantifier = IrUtils::combineDisjuncts; + } + Expression comparisonWithExtremeValue = + getBoundComparisons(quantifiedComparison, minValue, maxValue); + + return new SimpleCaseExpression( + countAllValue.toSymbolReference(), + ImmutableList.of(new WhenClause(new GenericLiteral("INT64", "0"), emptySetResult)), + quantifier.apply( + ImmutableList.of( + comparisonWithExtremeValue, + new SearchedCaseExpression( + ImmutableList.of( + new WhenClause( + new ComparisonExpression( + NOT_EQUAL, + countAllValue.toSymbolReference(), + countNonNullValue.toSymbolReference()), + new Cast(new NullLiteral(), toSqlType(BOOLEAN)))), + emptySetResult)))); + } + + private Expression getBoundComparisons( + ApplyNode.QuantifiedComparison quantifiedComparison, Symbol minValue, Symbol maxValue) { + if (mapOperator(quantifiedComparison) == EQUAL + && quantifiedComparison.getQuantifier() == ALL) { + // A = ALL B <=> min B = max B && A = min B + return combineConjuncts( + new ComparisonExpression( + EQUAL, minValue.toSymbolReference(), maxValue.toSymbolReference()), + new ComparisonExpression( + EQUAL, + quantifiedComparison.getValue().toSymbolReference(), + maxValue.toSymbolReference())); + } + + if (EnumSet.of(LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL) + .contains(mapOperator(quantifiedComparison))) { + // A < ALL B <=> A < min B + // A > ALL B <=> A > max B + // A < ANY B <=> A < max B + // A > ANY B <=> A > min B + Symbol boundValue = + shouldCompareValueWithLowerBound(quantifiedComparison) ? minValue : maxValue; + return new ComparisonExpression( + mapOperator(quantifiedComparison), + quantifiedComparison.getValue().toSymbolReference(), + boundValue.toSymbolReference()); + } + throw new IllegalArgumentException( + "Unsupported quantified comparison: " + quantifiedComparison); + } + + private static ComparisonExpression.Operator mapOperator( + ApplyNode.QuantifiedComparison quantifiedComparison) { + switch (quantifiedComparison.getOperator()) { + case EQUAL: + return EQUAL; + case NOT_EQUAL: + return NOT_EQUAL; + case LESS_THAN: + return LESS_THAN; + case LESS_THAN_OR_EQUAL: + return LESS_THAN_OR_EQUAL; + case GREATER_THAN: + return GREATER_THAN; + case GREATER_THAN_OR_EQUAL: + return GREATER_THAN_OR_EQUAL; + default: + throw new IllegalArgumentException( + "Unexpected quantifiedComparison: " + quantifiedComparison.getOperator()); + } + } + + private static boolean shouldCompareValueWithLowerBound( + ApplyNode.QuantifiedComparison quantifiedComparison) { + ComparisonExpression.Operator operator = mapOperator(quantifiedComparison); + switch (quantifiedComparison.getQuantifier()) { + case ALL: + switch (operator) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return true; + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return false; + default: + throw new IllegalArgumentException("Unexpected value: " + operator); + } + case ANY: + case SOME: + switch (operator) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return false; + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return true; + default: + throw new IllegalArgumentException("Unexpected value: " + operator); + } + default: + throw new IllegalArgumentException( + "Unexpected Quantifier: " + quantifiedComparison.getQuantifier()); + } + } + + private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) { + Assignments assignments = + Assignments.builder() + .putIdentities(input.getOutputSymbols()) + .putAll(subqueryAssignments) + .build(); + return new ProjectNode(idAllocator.genPlanNodeId(), input, assignments); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java index 43267f9e48e1..cad8f229c4eb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java @@ -62,6 +62,7 @@ protected SqlConstant() { public static final String LAST_AGGREGATION = "last"; public static final String FIRST_AGGREGATION = "first"; public static final String COUNT = "count"; + public static final String COUNT_ALL = "count_all"; public static final String AVG = "avg"; public static final String SUM = "sum"; public static final String COUNT_IF = "count_if"; diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java index 52321265280b..96dbe9d2b149 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java @@ -56,6 +56,7 @@ import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.FINAL; import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.INTERMEDIATE; import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.PARTIAL; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.SINGLE; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.EQUAL; public class SubqueryTest { @@ -322,4 +323,215 @@ public void testUncorrelatedNotInPredicateSubquery() { filterPredicate, semiJoin("s1", "s1_6", "expr", sort(tableScan1), sort(tableScan2)))))); } + + @Test + public void testUncorrelatedAnyComparisonSubquery() { + PlanTester planTester = new PlanTester(); + + String sql = "SELECT s1 FROM table1 where s1 > any (select s1 from table1)"; + + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + + PlanMatchPattern tableScan1 = + tableScan("testdb.table1", ImmutableList.of("s1"), ImmutableSet.of("s1")); + + PlanMatchPattern tableScan2 = tableScan("testdb.table1", ImmutableMap.of("s1_7", "s1")); + + PlanMatchPattern tableScan3 = tableScan("testdb.table1", ImmutableMap.of("s1_6", "s1")); + + // Verify full LogicalPlan + /* + * └──OutputNode + * └──ProjectNode + * └──FilterNode + * └──ProjectNode + * └──JoinNode + * |──TableScanNode + * ├──AggregationNode + * │ └──TableScanNode + + */ + assertPlan( + logicalQueryPlan, + output( + project( + anyTree( + project( + join( + JoinNode.JoinType.INNER, + builder -> + builder + .left(tableScan1) + .right( + aggregation( + singleGroupingSet(), + ImmutableMap.of( + Optional.of("min"), + aggregationFunction( + "min", ImmutableList.of("s1_7")), + Optional.of("count_all"), + aggregationFunction( + "count_all", ImmutableList.of("s1_7")), + Optional.of("count_non_null"), + aggregationFunction( + "count", ImmutableList.of("s1_7"))), + Collections.emptyList(), + Optional.empty(), + SINGLE, + tableScan2)))))))); + + // Verify DistributionPlan + assertPlan( + planTester.getFragmentPlan(0), + output( + project( + anyTree( + project( + join( + JoinNode.JoinType.INNER, + builder -> + builder + .left(collect(exchange(), tableScan1, exchange())) + .right( + aggregation( + singleGroupingSet(), + ImmutableMap.of( + Optional.of("min"), + aggregationFunction( + "min", ImmutableList.of("min_9")), + Optional.of("count_all"), + aggregationFunction( + "count_all", ImmutableList.of("count_all_10")), + Optional.of("count_non_null"), + aggregationFunction( + "count", ImmutableList.of("count"))), + Collections.emptyList(), + Optional.empty(), + FINAL, + collect( + exchange(), + aggregation( + singleGroupingSet(), + ImmutableMap.of( + Optional.of("min_9"), + aggregationFunction( + "min", ImmutableList.of("s1_6")), + Optional.of("count_all_10"), + aggregationFunction( + "count_all", ImmutableList.of("s1_6")), + Optional.of("count"), + aggregationFunction( + "count", ImmutableList.of("s1_6"))), + Collections.emptyList(), + Optional.empty(), + PARTIAL, + tableScan3), + exchange()))))))))); + + assertPlan(planTester.getFragmentPlan(1), tableScan1); + + assertPlan(planTester.getFragmentPlan(2), tableScan1); + + assertPlan( + planTester.getFragmentPlan(3), + aggregation( + singleGroupingSet(), + ImmutableMap.of( + Optional.of("min_9"), + aggregationFunction("min", ImmutableList.of("s1_6")), + Optional.of("count_all_10"), + aggregationFunction("count_all", ImmutableList.of("s1_6")), + Optional.of("count"), + aggregationFunction("count", ImmutableList.of("s1_6"))), + Collections.emptyList(), + Optional.empty(), + PARTIAL, + tableScan3)); + + assertPlan( + planTester.getFragmentPlan(4), + aggregation( + singleGroupingSet(), + ImmutableMap.of( + Optional.of("min_9"), + aggregationFunction("min", ImmutableList.of("s1_6")), + Optional.of("count_all_10"), + aggregationFunction("count_all", ImmutableList.of("s1_6")), + Optional.of("count"), + aggregationFunction("count", ImmutableList.of("s1_6"))), + Collections.emptyList(), + Optional.empty(), + PARTIAL, + tableScan3)); + } + + @Test + public void testUncorrelatedEqualsSomeComparisonSubquery() { + PlanTester planTester = new PlanTester(); + + String sql = "SELECT s1 FROM table1 where s1 = some (select s1 from table1)"; + + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + + Expression filterPredicate = new SymbolReference("expr"); + + PlanMatchPattern tableScan1 = + tableScan("testdb.table1", ImmutableList.of("s1"), ImmutableSet.of("s1")); + + PlanMatchPattern tableScan2 = tableScan("testdb.table1", ImmutableMap.of("s1_6", "s1")); + + // Verify full LogicalPlan + /* + * └──OutputNode + * └──ProjectNode + * └──FilterNode + * └──SemiJoinNode + * |──SortNode + * | └──TableScanNode + * ├──SortNode + * │ └──TableScanNode + + */ + assertPlan( + logicalQueryPlan, + output( + project( + filter( + filterPredicate, + semiJoin("s1", "s1_6", "expr", sort(tableScan1), sort(tableScan2)))))); + } + + @Test + public void testUncorrelatedAllComparisonSubquery() { + PlanTester planTester = new PlanTester(); + + String sql = "SELECT s1 FROM table1 where s1 != all (select s1 from table1)"; + + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + + PlanMatchPattern tableScan1 = + tableScan("testdb.table1", ImmutableList.of("s1"), ImmutableSet.of("s1")); + + PlanMatchPattern tableScan2 = tableScan("testdb.table1", ImmutableMap.of("s1_6", "s1")); + + // Verify full LogicalPlan + /* + * └──OutputNode + * └──ProjectNode + * └──FilterNode + * └──ProjectNode + * └──SemiJoinNode + * |──SortNode + * | └──TableScanNode + * ├──SortNode + * │ └──TableScanNode + + */ + assertPlan( + logicalQueryPlan, + output( + project( + anyTree( + project(semiJoin("s1", "s1_6", "expr", sort(tableScan1), sort(tableScan2))))))); + } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 10aa13ed4ad9..3d0510957d13 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -38,6 +38,7 @@ public enum TableBuiltinAggregationFunction { SUM("sum"), COUNT("count"), + COUNT_ALL("count_all"), COUNT_IF("count_if"), AVG("avg"), EXTREME("extreme"), @@ -82,6 +83,7 @@ public static Type getIntermediateType(String name, List originalArgumentT final String functionName = name.toLowerCase(); switch (functionName) { case "count": + case "count_all": case "count_if": return INT64; case "sum": diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index c46c4c0a65a0..93eafedc118e 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -281,7 +281,8 @@ enum TAggregationType { FIRST_BY, LAST_BY, MIN, - MAX + MAX, + COUNT_ALL } struct TShowConfigurationTemplateResp {